diff --git a/.clang-format-ignore b/.clang-format-ignore index 94dcaec5c9fe..e24e282f48dc 100644 --- a/.clang-format-ignore +++ b/.clang-format-ignore @@ -12,3 +12,4 @@ ./tutorial # hexagon_remote/bin/src is also special ./src/runtime/hexagon_remote/bin/src +./dependencies/spirv diff --git a/.clang-tidy b/.clang-tidy index 646be4bdbc83..808f93dbed15 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -6,11 +6,14 @@ Checks: > -*, bugprone-*, -bugprone-branch-clone, + -bugprone-easily-swappable-parameters, -bugprone-exception-escape, + -bugprone-implicit-widening-of-multiplication-result, -bugprone-integer-division, -bugprone-narrowing-conversions, -bugprone-reserved-identifier, -bugprone-signed-char-misuse, + clang-diagnostic-shadow-field, misc-*, -misc-no-recursion, -misc-non-private-member-variables-in-classes, @@ -22,7 +25,11 @@ Checks: > modernize-make-unique, modernize-redundant-void-arg, modernize-use-bool-literals, - modernize-use-default-member-init, + # Disabled: there is not consensus on whether the Clang-14 behavior + # of this checker is always desirable or not, and there isn't currently + # a way to revert to the Clang-13 behavior. We may revisit this + # check the next time we examine clang-tidy options. + # modernize-use-default-member-init, modernize-use-emplace, modernize-use-equals-default, modernize-use-equals-delete, @@ -51,7 +58,7 @@ Checks: > WarningsAsErrors: '*' HeaderFilterRegex: '.*' FormatStyle: 'file' -CheckOptions: - - key: modernize-use-default-member-init.UseAssignment - value: 1 +#CheckOptions: +# - key: modernize-use-default-member-init.UseAssignment +# value: 1 ... diff --git a/.github/workflows/packaging.yml b/.github/workflows/packaging.yml deleted file mode 100644 index c971d205447e..000000000000 --- a/.github/workflows/packaging.yml +++ /dev/null @@ -1,53 +0,0 @@ -name: Packaging -on: [ 'pull_request' ] -jobs: - package-ubuntu: - name: Package for Ubuntu - runs-on: ubuntu-20.04 - env: - CMAKE_CXX_COMPILER_LAUNCHER: ccache - CMAKE_C_COMPILER_LAUNCHER: ccache - LLVM_ROOT: /usr/lib/llvm-12 - steps: - - name: Install dependencies - run: | - wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null \ - | gpg --dearmor - | sudo tee /etc/apt/trusted.gpg.d/kitware.gpg >/dev/null - sudo apt-add-repository 'deb https://apt.kitware.com/ubuntu/ focal main' - sudo apt update - sudo apt install cmake ninja-build doxygen ccache - sudo apt install llvm-12-dev liblld-12-dev clang-12 libclang-12-dev libjpeg-dev libpng-dev - sudo apt install lintian dpkg-dev - - name: Check out sources - uses: actions/checkout@v2 - - name: Set up ccache - uses: hendrikmuhs/ccache-action@v1 - - name: Run Ubuntu packaging script - run: ./packaging/ubuntu/package.sh . ubuntu - - name: Upload packages - uses: actions/upload-artifact@v2 - with: - name: packages - path: ubuntu/*.deb - test-ubuntu: - name: Test Ubuntu package - needs: package-ubuntu - runs-on: ubuntu-20.04 - steps: - # Specifically use the CMake version that comes with Ubuntu. - - name: Install dependencies - run: | - sudo apt update - sudo apt install cmake ninja-build libc6-dev-arm64-cross gcc-aarch64-linux-gnu g++-aarch64-linux-gnu qemu-user - - name: Check out sources - uses: actions/checkout@v2 - - name: Download Halide Ubuntu packages - uses: actions/download-artifact@v2 - with: - name: packages - - name: Install Halide Ubuntu packages - run: sudo apt install ./*.deb - - name: Test integration - run: | - cmake -S test/integration -B build - cd build && ctest -j$(nproc) --output-on-failure diff --git a/.github/workflows/presubmit.yml b/.github/workflows/presubmit.yml index b370b2fe4848..f13c5ee80716 100644 --- a/.github/workflows/presubmit.yml +++ b/.github/workflows/presubmit.yml @@ -15,11 +15,11 @@ jobs: runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v2 - - uses: DoozyX/clang-format-lint-action@v0.12 + - uses: DoozyX/clang-format-lint-action@v0.14 with: source: '.' extensions: 'h,c,cpp' - clangFormatVersion: 12 + clangFormatVersion: 14 check_clang_tidy: name: Check clang-tidy runs-on: ubuntu-20.04 @@ -27,13 +27,17 @@ jobs: - uses: actions/checkout@v2 - name: Install clang-tidy run: | + # from apt.llvm.org + # wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - + sudo apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 15CF4D18AF4F7421 + sudo apt-add-repository "deb https://apt.llvm.org/$(lsb_release -sc)/ llvm-toolchain-$(lsb_release -sc)-14 main" sudo apt-get update - sudo apt-get install llvm-12 clang-12 liblld-12-dev libclang-12-dev clang-tidy-12 ninja-build + sudo apt-get install llvm-14 clang-14 liblld-14-dev libclang-14-dev clang-tidy-14 ninja-build - name: Run clang-tidy run: | - export CC=clang-12 - export CXX=clang++-12 - export CLANG_TIDY_LLVM_INSTALL_DIR=/usr/lib/llvm-12 + export CC=clang-14 + export CXX=clang++-14 + export CLANG_TIDY_LLVM_INSTALL_DIR=/usr/lib/llvm-14 ./run-clang-tidy.sh check_cmake_file_lists: name: Check CMake file lists diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml deleted file mode 100644 index f7e1dbea8c97..000000000000 --- a/.github/workflows/test.yml +++ /dev/null @@ -1,495 +0,0 @@ -# TODO (known issues) -# - no GPU tests are attempted (probably not possible) -# - cmake static builds aren't handled yet. -# - arm32 and arm64 is build-only, no testing (qemu is too slow). -# Perhaps some limited testing instead of none? -# - python is built+tested for x86-64 targets only (no arm or 32-bit) -# - apps are skipped for x86-32, arm-32, arm-64 -# -# TODO (stuff that could be usefully added, perhaps) -# - build + test of WASM -# - build + test of PyTorch -# -# TODO (GHA issues) -# - GHA is occasionally flaky and some VMs just fail, but there isn't a way -# to restart just one of the jobs (it's currently all-or-none) - -name: Halide Presubmit Build + Test -on: - workflow_dispatch: - # inputs: - # logLevel: - # description: 'Log level' - # required: true - # default: 'warning' - # tags: - # description: 'Test scenario tags' - - # pull_request: - # # We don't want 'edited' (that's basically just the description, title, etc) - # # We don't want 'review_requested' (that's redundant to the ones below for our purposes) - # types: [opened, synchronize, reopened] - # # TODO: do we want to limit this to certain filetypes? - # # paths: - # # - '**.h' - # # - '**.c' - # # - '**.cpp' - -jobs: - test_halide: - name: HL-${{matrix.llvm_version}}-${{matrix.target_arch}}-${{matrix.target_bits}}-${{matrix.target_os}}-${{matrix.build_tool}} - runs-on: ${{matrix.host_os}} - env: - CC: ${{matrix.cc}} - CXX: ${{matrix.cxx}} - LD: ${{matrix.ld}} - - strategy: - fail-fast: false # Keep running other jobs even if one fails - # free-tier projects (like Halide) get 20 concurrent tasks. - # The build matrix here has only 7 tasks -- should we limit it to fewer - # than that? Need to experiment. - # max-parallel: TBD TODO - matrix: - # TODO: this matrix is probably overkill; we don't need to build every combination. - # (Some combinations are nonsensical and excluded via the 'exclude:' section below.) - target_arch: [x86, arm] - target_bits: [32, 64] - target_os: [linux, osx, windows] - llvm_version: [12] - build_tool: [cmake_shared] - # llvm_version: [10, 11, 12] # TODO - # build_tool: [cmake_shared, make] # TODO - - # This section basically allows us to define additional values for - # each matrix entry, e.g. to map an llvm version number to the specific - # git branch that is needed. - include: - # - llvm_version: 10 - # llvm_branch: release/10.x - # - llvm_version: 11 - # llvm_branch: release/11.x - - llvm_version: 12 - llvm_branch: master - - # map things to the necessary host cross-compiler host - - target_os: osx - host_os: macos-10.15 - cc: clang - cxx: clang++ - ld: ld - - - target_os: linux - host_os: ubuntu-18.04 - # GHA has clang 6, 8, and 9 and GCC 7.4, 8.3, 9.2 preinstalled. - # We will explicitly choose gcc 7.x (even though the default is gcc 7.4) - # to ensure we match gcc versions with the arm crosscompiler. - cc: gcc-7 - cxx: g++-7 - ld: ld - - - target_os: windows - host_os: windows-2019 - cc: cl.exe - cxx: cl.exe - ld: ld.exe - - - target_arch: x86 - python_version: '3.7' - uses_python: true - run_tests: true - - - target_bits: 32 - # We don't build/test Python bindings on any 32-bit targets - uses_python: false - - - target_arch: arm - # We don't build/test Python bindings on any ARM targets - uses_python: false - # Running our test suite (via e.g. QEMU) is too slow to be useful - # at present (> 5 hours on current GHA VMs). That said, we'll leave - # in the relevant code for now (disabled via this flag) in case - # it proves useful later. - run_tests: false - - exclude: - - target_os: osx - target_arch: arm # OSX is x86-only - - target_os: osx - target_bits: 32 # OSX is 64-bit only - - target_os: windows - target_arch: arm # OSX is x86-only - - target_os: windows - build_tool: make # Windows is CMake-only - - target_arch: arm - build_tool: make # In this setup, all ARM builds require CMake - - steps: - - uses: actions/checkout@v2 - with: - path: 'halide' - - - name: Configure Python - if: matrix.uses_python - uses: actions/setup-python@v1 - with: - python-version: '${{matrix.python_version}}' - architecture: 'x64' - - - name: Configure Ubuntu Host - if: startsWith(matrix.host_os, 'ubuntu') - shell: bash - run: | - sudo apt-get update - - sudo apt-get install \ - doxygen \ - libjpeg-dev \ - libpng-dev \ - ninja-build - - - name: Configure MacOS Host - if: startsWith(matrix.host_os, 'macos') - shell: bash - run: | - # coreutils is for gtimeout - brew install \ - coreutils \ - doxygen \ - jpeg \ - libpng \ - ninja - - - name: Configure Windows Host - if: startsWith(matrix.host_os, 'windows') - shell: bash - run: | - if [[ ${{matrix.target_bits}} == 32 ]]; then - export VCPKG_DEFAULT_TRIPLET=x86-windows - else - export VCPKG_DEFAULT_TRIPLET=x64-windows - fi - - vcpkg install \ - libjpeg-turbo \ - libpng \ - zlib - - - name: Configure x86-32 Crosscompilation - if: matrix.target_os == 'linux' && matrix.target_arch == 'x86' && matrix.target_bits == 32 - shell: bash - run: | - sudo dpkg --add-architecture i386 - sudo apt-get update - sudo apt-get install \ - ${{matrix.cc}}-multilib \ - ${{matrix.cxx}}-multilib \ - libjpeg-dev:i386 \ - libpng-dev:i386 \ - - - name: Configure Arm32 Crosscompilation - if: matrix.target_os == 'linux' && matrix.target_arch == 'arm' && matrix.target_bits == 32 - shell: bash - run: | - # Note that we are configuring this for user-mode emulation: - # syscalls will be native, only user-mode code will be emulated. - # This is not 100% perfect (there are various corner cases that - # can bite us), but is *much* faster than full machine emulation. - - sudo apt-get update - sudo apt-get install --install-suggests \ - ${{matrix.cc}}-arm-linux-gnueabihf \ - ${{matrix.cxx}}-arm-linux-gnueabihf - - # TODO: figure out how to install libjpeg and libpng for armhf; - # the standard apt repository for GHA VMs barfs on these. - # sudo apt-get install \ - # libjpeg-dev:armhf \ - # libpng-dev:armhf - - # Note that we need QEMU even if not running tests, as Generators - # will be built for arm by default, and we need to be able to run them. - sudo apt-get install --install-suggests \ - qemu-user \ - qemu-user-binfmt - - qemu-arm --version - echo ::set-env name=QEMU_LD_PREFIX::"/usr/arm-linux-gnueabihf" - - - name: Configure AArch64 Crosscompilation - if: matrix.target_os == 'linux' && matrix.target_arch == 'arm' && matrix.target_bits == 64 - shell: bash - run: | - sudo apt-get update - sudo apt-get install --install-suggests \ - ${{matrix.cc}}-aarch64-linux-gnu \ - ${{matrix.cxx}}-aarch64-linux-gnu - - # TODO: figure out how to install libjpeg and libpng for armhf; - # the standard apt repository for GHA VMs barfs on these. - # sudo apt-get install \ - # libjpeg-dev:aarch64 \ - # libpng-dev:aarch64 - - # Note that we need QEMU even if not running tests, as Generators - # will be built for arm by default, and we need to be able to run them. - sudo apt-get install --install-suggests \ - qemu-user \ - qemu-user-binfmt - - qemu-arm --version - echo ::set-env name=QEMU_LD_PREFIX::"/usr/aarch64-linux-gnu" - - - name: Configure Env Vars - shell: bash - run: | - echo "github.event_name is ${{github.event_name}}" # should always be "pull_request" - echo "github.event.action is ${{github.event.action}}" - - # Demangle Windows names, to simplify CMake stuff later - _ROOT=${GITHUB_WORKSPACE//\\//} - _TEMP_RAW="${{runner.temp}}" - _TEMP=${_TEMP_RAW//\\//} - - # This is the trick GitHub Actions uses to allow us to set env vars across all subsequent job steps - echo ::set-env name=BUILD_TYPE::"Release" - echo ::set-env name=LLVM_INSTALL_DIR::"${_ROOT}/llvm" - echo ::set-env name=LLVM_CONFIG::"${_ROOT}/llvm/bin/llvm-config" - echo ::set-env name=HALIDE_SOURCE_DIR::"${_ROOT}/halide" - echo ::set-env name=HALIDE_BUILD_DIR::"${_ROOT}/halide_build" - echo ::set-env name=HALIDE_TEMP_DIR::"${_TEMP}" - echo ::set-env name=PARALLEL_JOBS::"4" - if [[ ${{matrix.host_os}} == windows* ]]; then - # On Windows, it's just 'python', apparently - echo ::set-env name=PYTHON::"python" - else - echo ::set-env name=PYTHON::"python${{matrix.python_version}}" - fi - - - name: Install Python Dependencies - if: matrix.uses_python - shell: bash - run: | - ${PYTHON} -m pip --version - ${PYTHON} -m pip install --upgrade pip - ${PYTHON} -m pip install -r ${HALIDE_SOURCE_DIR}/python_bindings/requirements.txt - - echo ::set-env name=PYTHON::"${PYTHON}" - - - name: Install LLVM - shell: bash - run: | - LLVM_ID="llvm-${{matrix.llvm_version}}-${{matrix.target_arch}}-${{matrix.target_bits}}-${{matrix.target_os}}" - - curl \ - --user llvm_user:${{secrets.LLVM_USER_PASSWORD}} \ - --output ${HALIDE_TEMP_DIR}/llvm-prebuilt.tgz \ - https://buildbot.halide-lang.org/llvm/${LLVM_ID}.tgz - - TAR_CMD="tar" - if [[ ${{matrix.host_os}} == windows* ]]; then - # Must use --force-local to avoid tar misinterpreting the : in - # a Windows pathname as a hostname. - TAR_CMD="tar --force-local" - fi - - mkdir ${LLVM_INSTALL_DIR} - ${TAR_CMD} -xf ${HALIDE_TEMP_DIR}/llvm-prebuilt.tgz -C ${LLVM_INSTALL_DIR} - rm -rf ${HALIDE_TEMP_DIR}/llvm-prebuilt.tgz - - LLVM_COMMIT_HASH=`cat ${LLVM_INSTALL_DIR}/.halide_builder_llvm_commit` - echo "Using LLVM v${{matrix.llvm_version}} commit=${LLVM_COMMIT_HASH}" - - - name: Configure Halide (Make) - if: startsWith(matrix.build_tool, 'make') - shell: bash - run: | - # Configure Make - mkdir ${HALIDE_BUILD_DIR} - - if [[ ${{matrix.target_arch}} == x86 && \ - ${{matrix.target_os}} == linux && \ - ${{matrix.target_bits}} == 32 ]]; then - echo ::set-env name=CC::"${CC} -m32" - echo ::set-env name=CXX::"${CXX} -m32" - echo ::set-env name=LD::"${LD} -melf_i386" - fi - - - name: Configure Halide (CMake) - if: startsWith(matrix.build_tool, 'cmake') - shell: bash - run: | - # Configure CMake - echo `cmake --version` - - mkdir ${HALIDE_BUILD_DIR} - - CMAKE_GEN="Ninja" - EXTRA_CMAKE_FLAGS= - - if [[ ${{matrix.host_os}} == windows* ]]; then - CMAKE_GEN="Visual Studio 16 2019" - - # CMAKE_TOOLCHAIN_FILE is necessary for CMake to find things installed by vcpkg - EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} \ - -D CMAKE_TOOLCHAIN_FILE=${VCPKG_INSTALLATION_ROOT//\\//}/scripts/buildsystems/vcpkg.cmake \ - -T host=x64" - if [[ ${{matrix.target_bits}} == 32 ]]; then - EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} -A Win32" - else - EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} -A x64" - fi - fi - - if [[ ${{matrix.target_arch}} == x86 && \ - ${{matrix.target_os}} == linux && \ - ${{matrix.target_bits}} == 32 ]]; then - # Assume host_os is ubuntu* - EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} \ - -D CMAKE_TOOLCHAIN_FILE=${HALIDE_SOURCE_DIR}/cmake/toolchain.linux-i386.cmake" - fi - - if [[ ${{matrix.target_os}} == osx ]]; then - # LLVM_ENABLE_SUPPORT_XCODE_SIGNPOSTS=OFF is needed for compatibility with older XCode versions - EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} \ - -D LLVM_ENABLE_SUPPORT_XCODE_SIGNPOSTS=FORCE_OFF" - fi - - if [[ ${{matrix.target_arch}} == arm ]]; then - # The arm toolchain files default to "gcc"/"g++" with no version appended, - # but we installed specific versions, so be sure it can find those specific versions. - if [[ ${{matrix.target_bits}} == 32 ]]; then - export ARCH_FOR_TESTS=armv7-a - EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} \ - -D CMAKE_C_COMPILER=arm-linux-gnueabihf-${{matrix.cc}} \ - -D CMAKE_CXX_COMPILER=arm-linux-gnueabihf-${{matrix.cxx}} \ - -D CMAKE_TOOLCHAIN_FILE=${HALIDE_SOURCE_DIR}/cmake/toolchain.linux-arm32.cmake" - else - export ARCH_FOR_TESTS=armv8-a - EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} \ - -D CMAKE_C_COMPILER=aarch64-linux-gnu-${{matrix.cc}} \ - -D CMAKE_CXX_COMPILER=aarch64-linux-gnu-${{matrix.cxx}} \ - -D CMAKE_TOOLCHAIN_FILE=${HALIDE_SOURCE_DIR}/cmake/toolchain.linux-aarch64.cmake" - fi - fi - - REQUIRE_LLVM_VERSION="${{matrix.llvm_version}}0" - SHARED_LIBRARY=$([ ${{matrix.build_tool}} == "cmake_shared" ] && echo "ON" || echo "OFF") - - if [[ "${{matrix.uses_python}}" == "true" ]]; then - EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} \ - -D WITH_PYTHON_BINDINGS=ON" - else - EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} \ - -D WITH_PYTHON_BINDINGS=OFF" - fi - - cmake \ - -D CMAKE_BUILD_TYPE=${BUILD_TYPE} \ - -D LLVM_DIR="${LLVM_INSTALL_DIR}/lib/cmake/llvm" \ - -D HALIDE_REQUIRE_LLVM_VERSION="${REQUIRE_LLVM_VERSION}" \ - -D HALIDE_SHARED_LIBRARY=${SHARED_LIBRARY} \ - -G "${CMAKE_GEN}" \ - ${EXTRA_CMAKE_FLAGS} \ - -S "${HALIDE_SOURCE_DIR}" \ - -B "${HALIDE_BUILD_DIR}" - - - name: Build Halide (Make) - if: startsWith(matrix.build_tool, 'make') - shell: bash - run: | - # Build Halide - cd ${HALIDE_BUILD_DIR} - - BUILD_TARGETS="distrib build_tests" - if [[ "${{matrix.uses_python}}" == "true" ]]; then - # build_apps requires the python bindings - BUILD_TARGETS="${BUILD_TARGETS} build_apps build_python_bindings" - fi - - make -f ${HALIDE_SOURCE_DIR}/Makefile -j ${PARALLEL_JOBS} ${BUILD_TARGETS} - - - name: Build Halide (CMake) - if: startsWith(matrix.build_tool, 'cmake') - shell: bash - run: | - # Build Halide - cd ${HALIDE_BUILD_DIR} - cmake \ - --build ${HALIDE_BUILD_DIR} \ - --config ${BUILD_TYPE} \ - -j ${PARALLEL_JOBS} - - - name: Run Tests (Make) - if: matrix.run_tests && startsWith(matrix.build_tool, 'make') - shell: bash - run: | - # Test Halide - export TEST_TMPDIR="${HALIDE_TEMP_DIR}" - cd ${HALIDE_BUILD_DIR} - - TEST_GROUPS_PARALLEL="internal correctness error warning generator" - if [[ "${{matrix.uses_python}}" == "true" ]]; then - TEST_GROUPS_PARALLEL="${TEST_GROUPS_PARALLEL} python" - fi - - # tutorial has some performance measurements that can be flaky if we run them in parallel - TEST_GROUPS_SERIAL="tutorial" - - # performance is never going to be reliable on VMs. - # auto_schedule is just flaky. - TEST_GROUPS_BROKEN="performance auto_schedule" - - if [[ ${{matrix.target_bits}} == 32 ]]; then - # TODO: Skip testing apps on 32-bit systems for now; - # in particular, apps/autoscheduler can time out, and also has build - # issues on ubuntu-32 at the moment (__udivdi3). - TEST_GROUPS_BROKEN="${TEST_GROUPS_BROKEN} apps" - else - TEST_GROUPS_PARALLEL="${TEST_GROUPS_PARALLEL} apps" - fi - - # Parallel - for t in ${TEST_GROUPS_PARALLEL}; do - make -f ${HALIDE_SOURCE_DIR}/Makefile -j ${PARALLEL_JOBS} test_${t} - done - - # Serial - for t in ${TEST_GROUPS_SERIAL}; do - make -f ${HALIDE_SOURCE_DIR}/Makefile test_$t - done - - - name: Run Tests (CMake) - if: matrix.run_tests && startsWith(matrix.build_tool, 'cmake') - shell: bash - run: | - # Test Halide - TEST_GROUPS_PARALLEL="internal|correctness|error|warning|generator" - - if [[ "${{matrix.uses_python}}" == "true" ]]; then - TEST_GROUPS_PARALLEL="${TEST_GROUPS_PARALLEL}|python" - fi - - # tutorial has some performance measurements that can be flaky if we run them in parallel - TEST_GROUPS_SERIAL="tutorial" - - # performance is never going to be reliable on VMs. - # auto_schedule is just flaky. - TEST_GROUPS_BROKEN="performance|auto_schedule" - - export TEST_TMPDIR="${HALIDE_TEMP_DIR}" - cd ${HALIDE_BUILD_DIR} - - # Parallel - ctest \ - -C ${BUILD_TYPE} \ - -j ${PARALLEL_JOBS} \ - -L "${TEST_GROUPS_PARALLEL}" \ - --output-on-failure - - # Serial - ctest \ - -C ${BUILD_TYPE} \ - -L "${TEST_GROUPS_SERIAL}" \ - -E "${TEST_GROUPS_BROKEN}" \ - --output-on-failure diff --git a/.gitignore b/.gitignore index 286672e22d2a..958a6c93186f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,97 +1,270 @@ -CMakeUserPresets.json +# NOTE: one can debug these rules with the following commands: +# +# $ git clean -ffdx +# $ find . -not -path './.git/*' | git check-ignore --stdin --no-index +# +# The first command will delete all files that are ignored by Git (be warned!). +# The second command will print all files that are checked in, but _would be_ +# ignored under the rules in this file. Such files should either be explicitly +# added to the exclusions at the bottom of this file, or the rule excluding them +# should be refined. -/tutorial/figures/tmp/trace.bin -/apps/*/bin -/apps/*/cmake_build -/apps/HelloMatlab/blurred.png -/apps/HelloMatlab/iir_blur.mex -bin/* -build/* -share/* -python_bindings/bin/* -build-64/* -build-ios/* -build-osx/* -build_wasm*/* -cmake_build*/* -*/build/* -tmp/* -include/* -distrib/* -testing/* -msvc/*/Win32/* -msvc/*/x64/* -\#*\# -.\#* -*~ -*.o +################################################################################ +## Exclude files without extensions + +* +!*.* +!*/ + +################################################################################ +## Halide-specific exclusions + +# Images only allowed in apps and directories named "images" +*.png +!apps/**/*.png +!**/images/**/*.png + +# Pre-trained weights only allowed in autoscheduler directories +*.weights +!src/autoschedulers/**/*.weights + +################################################################################ +## Halide-specific build artifacts + +# Apps +apps/*/*.def +apps/*/*.ptx +apps/*/*.sass +apps/*/*out*.png +apps/*/filter +apps/*/passes.txt +apps/HelloAndroidGL/jni/halide_gl_filter.h + +# Autoschedulers +**/src/autoschedulers/adams2019/baseline.cpp +**/src/autoschedulers/adams2019/cost_model.h +**/src/autoschedulers/adams2019/demo.h +**/src/autoschedulers/adams2019/included_schedule_file.h +**/src/autoschedulers/adams2019/train_cost_model.h +**/src/autoschedulers/li2018/demo_gradient.h + +# CMake configuration +Halide-*-deps.cmake + +# Distribution headers +**/include/Halide*.h +**/include/wasm-rt*.h + +# Generator executables +*.generator + +# Generator outputs +*.bc +*.featurization +*.halide_compiler_log +*.halide_generated.cpp +*.ll +*.py.cpp +*.pytorch.h +*.registration.cpp +*.s +*.schedule.h +*.stmt +*.stmt.html +*.stub.h + +# Linker scripts +py_*.ldscript* + +# Runtime modules +_initmod*.cpp + +# Tests +**/python_bindings/correctness/generators/*.h +**/test/generator/*.h +**/test/generator/external_code_extern_bitcode_32.cpp +**/test/generator/external_code_extern_bitcode_64.cpp +**/test/generator/external_code_extern_cpp_source.cpp +compile_log.txt +stderr.txt +stdout.txt + +# Tutorials +**/tutorial/auto_schedule_false.h +**/tutorial/auto_schedule_true.h +**/tutorial/brighten_either.h +**/tutorial/brighten_interleaved.h +**/tutorial/brighten_planar.h +**/tutorial/brighten_specialized.h +**/tutorial/lesson_10_halide.h +**/tutorial/my_first_generator_win32.h +**/tutorial/my_first_generator.h +**/tutorial/my_second_generator_1.h +**/tutorial/my_second_generator_2.h +**/tutorial/my_second_generator_3.h + +# Tutorial images that were copied to the install tree +**/tutorial/images/ +!tutorial/images/ + +################################################################################ +## Common build artifacts + +# Directories +bin/ +distrib/ +lib/ +lib64/ +share/ + +# Binaries *.a +*.cubin +*.dll +*.dylib +*.exe +*.lib +*.o +*.obj *.so -*.dot -.DS_Store -*.log -generated.obj -hello-fimage -test.s -testX64 -xcuserdata -in.png -_build +*.so.* a.out -*.bc -*.cubin -tags -src/*.top -llvm_svn -llvm/* -cpp/* -*.h.gch -src/test_*.ll -src/test_*.s -.clang_complete -*.guru + +# Compiler intermediates / debugging info +*.[ip]db +*.[pg]ch +*.d *.dSYM -.*.swp -tools/objc/BUILD -tools/objc/*.mobileprovision -*.xcworkspacedata +# Package files +*.deb +*.tar.gz +*.tgz +*.zip + +################################################################################ +## Temporary and swap files + +temp/ +tmp/ +.*.swp +.\#* +.DS_Store +*.log +*.tmp *.txt.user* +*~ +\#*\# + +################################################################################ +## Python + +# Common virtual environment directory names +.venv/ +venv/ + +# Python binary caches +__pycache__ +*.py[cod] + +# Python package build artifacts +*.egg-info/ +*.whl +MANIFEST.in + +################################################################################ +## CMake + +# User-specific configuration files +CMakeUserPresets.json + +# Common build directory names +build*/ +cmake[-_]build*/ + +# Generated config files +*-config-version.cmake +*-config.cmake +*Config.cmake +*ConfigVersion.cmake + +# Build directory contents +_deps/ +.cmake/ +cmake_install.cmake +CMakeCache.txt +CMakeFiles/ +compile_commands.json +CPack*.cmake +CTest*.cmake +CTest*.txt +install_manifest.txt + +# Ninja files +*.ninja* + +################################################################################ +## IDE directories and metadata + +# Visual Studio +.vs/ +out/ + +CMakeSettings.json + +# XCode +*.xcworkspacedata +tools/objc/*.mobileprovision +tools/objc/BUILD +xcuserdata + +# CLion .idea/ -# jrk editor settings +# VSCode +.vscode/ + +# TextMate .tm_properties -*.sublime-project -*.sublime-workspace -apps/patchmatch +# Sublime Text +.tags +.tags_sorted_by_file +*.sublime-* -# app intermediates -apps/*/*out*.png -apps/*/*.lowered -apps/*/*.def -apps/*/*.ptx -apps/*/passes.txt -apps/*/*.ll -apps/*/*.sass -apps/*/filter -apps/HelloAndroidGL/jni/halide_gl_filter.h +# Vim +.clang_complete + +# Emacs +tags +TAGS + +################################################################################ +## Halide-specific rule overrides + +# Allow particular extension-less files +!gradlew +!Makefile +!packaging/ubuntu/changelog +!packaging/ubuntu/copyright +!packaging/ubuntu/triggers + +# Allow XCode PCHs in the HelloiOS app +!apps/HelloiOS/**/*-Prefix.pch + +# Allow the runtime to have handwritten LLVM modules +!src/runtime/*.ll -# tutorial intermediates -tutorial/lesson_10_halide.h +# Allow precompiled Nvidia bitcode +!src/runtime/nvidia_libdevice_bitcode/*.bc -# test intermediates -log.txt -err.txt -test/*.lowered -*.pyc +# Anything goes in the hexagon_remote binaries +!src/runtime/hexagon_remote/**/* -src/.tags -src/.tags_sorted_by_file +# TODO: should this be checked in? +!src/autoschedulers/adams2019/included_schedule_file.schedule.h -/.vs -/out -/CMakeSettings.json -/venv/ -/cmake-build-*/ +# TODO: these should become .cmake.in +!packaging/common/HalideConfig.cmake +!packaging/common/HalideHelpersConfig.cmake diff --git a/CMakeLists.txt b/CMakeLists.txt index 55f279589995..b12b6c7cebf4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,37 +1,22 @@ -cmake_minimum_required(VERSION 3.16...3.20) +cmake_minimum_required(VERSION 3.22...3.23) project(Halide - VERSION 14.0.0 + VERSION 15.0.0 DESCRIPTION "Halide compiler and libraries" HOMEPAGE_URL "https://halide-lang.org") enable_testing() -find_package(Python3 REQUIRED COMPONENTS Interpreter Development) - -# Configure pybind11 to use the same interpreter version as was detected above. -message(STATUS "Directing pybind11 to Python3 executable ${Python3_EXECUTABLE}") -set(PYTHON_EXECUTABLE ${Python3_EXECUTABLE}) - -# Keep the version in sync with requirements.txt and the Ubuntu 20.04 LTS package (python3-pybind11) -set(PYBIND11_VER 2.6.1) -find_package(pybind11 ${PYBIND11_VER} QUIET) -if (NOT pybind11_FOUND) - include(FetchContent) - FetchContent_Declare(pybind11 - GIT_REPOSITORY https://github.com/pybind/pybind11.git - GIT_TAG v${PYBIND11_VER}) - FetchContent_MakeAvailable(pybind11) -endif () - ## # Set up project-wide properties ## +# Import useful standard modules +include(CMakeDependentOption) +include(CheckCXXSymbolExists) + # Make our custom helpers available throughout the project via include(). list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_LIST_DIR}/cmake) include(HalideGeneratorHelpers) -include(MakeShellPath) -include(CMakeDependentOption) # Build Halide as a shared lib by default, but still honor command-line settings. option(BUILD_SHARED_LIBS "Build shared libraries" ON) @@ -67,24 +52,35 @@ endif() # Build Halide with ccache if the package is present option(Halide_CCACHE_BUILD "Set to ON for a ccache enabled build" OFF) mark_as_advanced(Halide_CCACHE_BUILD) + if (Halide_CCACHE_BUILD) - find_program(CCACHE_PROGRAM ccache) - if (CCACHE_PROGRAM) - # TODO: ccache recommends setting CCACHE_SLOPPINESS=pch_defines,time_macros to - # enable precompiled header caching. Our timing found it slightly faster with - # just CCACHE_SLOPPINESS=pch_defines, so that's what we're using. Maybe revisit - # if issues occur (but we don't use any of the time macros so should be irrelevant). - set(Halide_CCACHE_PARAMS CCACHE_CPP2=yes CCACHE_HASHDIR=yes CCACHE_SLOPPINESS=pch_defines - CACHE STRING "Parameters to pass through to ccache") - mark_as_advanced(Halide_CCACHE_PARAMS) - set(CMAKE_C_COMPILER_LAUNCHER ${CMAKE_COMMAND} -E env ${Halide_CCACHE_PARAMS} ${CCACHE_PROGRAM}) - set(CMAKE_CXX_COMPILER_LAUNCHER ${CMAKE_COMMAND} -E env ${Halide_CCACHE_PARAMS} ${CCACHE_PROGRAM}) - message(STATUS "Enabling ccache usage for building.") - else () - message(FATAL_ERROR "Unable to find the program ccache. Set Halide_CCACHE_BUILD to OFF") - endif () + find_program(CCACHE_PROGRAM ccache REQUIRED) + + # TODO: ccache recommends setting CCACHE_SLOPPINESS=pch_defines,time_macros to + # enable precompiled header caching. Our timing found it slightly faster with + # just CCACHE_SLOPPINESS=pch_defines, so that's what we're using. Maybe revisit + # if issues occur (but we don't use any of the time macros so should be irrelevant). + set(Halide_CCACHE_PARAMS CCACHE_CPP2=yes CCACHE_HASHDIR=yes CCACHE_SLOPPINESS=pch_defines + CACHE STRING "Parameters to pass through to ccache") + mark_as_advanced(Halide_CCACHE_PARAMS) + + set(CMAKE_C_COMPILER_LAUNCHER ${CMAKE_COMMAND} -E env ${Halide_CCACHE_PARAMS} ${CCACHE_PROGRAM}) + set(CMAKE_CXX_COMPILER_LAUNCHER ${CMAKE_COMMAND} -E env ${Halide_CCACHE_PARAMS} ${CCACHE_PROGRAM}) + + message(STATUS "Enabling ccache usage for building.") endif () +# Detect whether or not ASAN is enabled +check_cxx_symbol_exists(HALIDE_INTERNAL_USING_ASAN "${Halide_SOURCE_DIR}/src/Util.h" Halide_ASAN_ENABLED) +if (Halide_ASAN_ENABLED) + set(Halide_ANY_SANITIZERS_ENABLED 1) +else () + set(Halide_ANY_SANITIZERS_ENABLED 0) +endif () + +# Enable the SPIR-V target if requested (must declare before processing dependencies) +option(TARGET_SPIRV "Include SPIR-V target" OFF) + ## # Import dependencies ## @@ -111,20 +107,13 @@ if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) message(STATUS "Building tests disabled") endif () - option(WITH_PYTHON_BINDINGS "Build Python bindings" ON) + cmake_dependent_option( + WITH_PYTHON_BINDINGS "Build Python bindings" ON + "Halide_ENABLE_RTTI;Halide_ENABLE_EXCEPTIONS" OFF + ) if (WITH_PYTHON_BINDINGS) - if (Halide_ENABLE_RTTI AND Halide_ENABLE_EXCEPTIONS) - message(STATUS "Building Python bindings enabled") - add_subdirectory(python_bindings) - else () - if (NOT Halide_ENABLE_RTTI) - message(WARNING "Building Python bindings disabled: must compile with RTTI") - endif () - if (NOT Halide_ENABLE_EXCEPTIONS) - message(WARNING "Building Python bindings disabled: must compile with exceptions") - endif () - set(WITH_PYTHON_BINDINGS OFF CACHE BOOL "Build Python bindings" FORCE) - endif () + message(STATUS "Building Python bindings enabled") + add_subdirectory(python_bindings) else () message(STATUS "Building Python bindings disabled") endif () diff --git a/CMakePresets.json b/CMakePresets.json index 3d20fd0aa27a..7fa0ec43b2f0 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -1,73 +1,105 @@ { - "version": 1, + "version": 3, "cmakeMinimumRequired": { "major": 3, - "minor": 16, + "minor": 22, "patch": 0 }, "configurePresets": [ { - "name": "gcc-debug", - "displayName": "GCC (Debug)", - "description": "Debug build using Ninja generator and GCC-compatible compiler", - "generator": "Ninja", - "binaryDir": "${sourceDir}/build", + "name": "base", + "hidden": true, + "binaryDir": "build/${presetName}", + "installDir": "install/${presetName}" + }, + { + "name": "ci", + "hidden": true, + "inherits": "base", + "toolchainFile": "${sourceDir}/cmake/toolchain.${presetName}.cmake", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "RelWithDebInfo" + } + }, + { + "name": "windows-only", + "hidden": true, + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Windows" + } + }, + { + "name": "vcpkg", + "hidden": true, + "toolchainFile": "$env{VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" + }, + { + "name": "vs2019", + "hidden": true, + "inherits": [ + "vcpkg", + "windows-only" + ], + "generator": "Visual Studio 16 2019", + "toolset": "host=x64" + }, + { + "name": "debug", + "inherits": "base", + "displayName": "Debug", + "description": "Debug build with no special settings", "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug" } }, { - "name": "gcc-release", - "inherits": "gcc-debug", - "displayName": "GCC (Release)", - "description": "Release build using Ninja generator and GCC-compatible compiler", + "name": "release", + "inherits": "base", + "displayName": "Release", + "description": "Release build with no special settings", "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } }, { - "name": "msvc-debug", - "displayName": "MSVC (Debug)", - "description": "Debug build using Ninja generator and MSVC with vcpkg dependencies.", - "generator": "Ninja", - "binaryDir": "${sourceDir}/build", + "name": "debian-debug", + "inherits": "debug", + "displayName": "Debian (Debug)", + "description": "Debug build assuming Debian-provided dependencies", "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_TOOLCHAIN_FILE": "$env{VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" + "Halide_SHARED_LLVM": "ON" } }, { - "name": "msvc-release", - "inherits": "msvc-debug", - "displayName": "MSVC (Release)", - "description": "Release build using Ninja generator and MSVC with vcpkg dependencies.", + "name": "debian-release", + "inherits": "debian-debug", + "displayName": "Debian (Release)", + "description": "Release build assuming Debian-provided dependencies", "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } }, { "name": "win32", + "inherits": [ + "vs2019", + "base" + ], "displayName": "Win32 (Visual Studio)", "description": "Visual Studio-based Win32 build with vcpkg dependencies.", - "generator": "Visual Studio 16 2019", - "architecture": "Win32", - "toolset": "host=x64", - "binaryDir": "${sourceDir}/build", - "cacheVariables": { - "CMAKE_TOOLCHAIN_FILE": "$env{VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" - } + "architecture": "Win32" }, { "name": "win64", + "inherits": [ + "vs2019", + "base" + ], "displayName": "Win64 (Visual Studio)", - "description": "Visual Studio-based Win64 build with vcpkg dependencies.", - "generator": "Visual Studio 16 2019", - "architecture": "x64", - "toolset": "host=x64", - "binaryDir": "${sourceDir}/build", - "cacheVariables": { - "CMAKE_TOOLCHAIN_FILE": "$env{VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" - } + "description": "Visual Studio-based x64 build with vcpkg dependencies.", + "architecture": "x64" }, { "name": "package", @@ -87,14 +119,14 @@ }, { "name": "package-windows", - "inherits": "package", + "inherits": [ + "package", + "vs2019" + ], "displayName": "Package ZIP for Windows", "description": "Build for packaging Windows shared libraries.", "binaryDir": "${sourceDir}/build", - "generator": "Visual Studio 16 2019", - "toolset": "host=x64", "cacheVariables": { - "CMAKE_TOOLCHAIN_FILE": "$env{VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake", "BUILD_SHARED_LIBS": "YES", "CMAKE_INSTALL_BINDIR": "bin/$", "CMAKE_INSTALL_LIBDIR": "lib/$", @@ -102,15 +134,9 @@ "Halide_INSTALL_HELPERSDIR": "lib/cmake/HalideHelpers" } }, - { - "name": "package-unix", - "hidden": true, - "inherits": "package", - "generator": "Ninja" - }, { "name": "package-unix-shared", - "inherits": "package-unix", + "inherits": "package", "displayName": "Package UNIX shared libs", "description": "Build for packaging UNIX shared libraries.", "binaryDir": "shared-Release", @@ -120,7 +146,7 @@ }, { "name": "package-unix-static", - "inherits": "package-unix", + "inherits": "package", "displayName": "Package UNIX static libs", "description": "Build for packaging UNIX static libraries.", "binaryDir": "static-Release", @@ -130,32 +156,64 @@ } }, { - "name": "package-ubuntu-shared", - "inherits": "package-unix-shared", - "displayName": "Package shared Halide for Ubuntu", - "description": "Package shared Halide for Ubuntu, using system packages.", - "binaryDir": "shared-release", + "name": "linux-x64-asan", + "inherits": "ci", + "displayName": "ASAN (Linux x64)", + "description": "Build everything with ASAN enabled", "cacheVariables": { - "Halide_SHARED_LLVM": "YES", - "LLVM_DIR": "$env{LLVM_ROOT}/lib/cmake/llvm", - "Clang_DIR": "$env{LLVM_ROOT}/lib/cmake/clang", - "LLD_DIR": "$env{LLVM_ROOT}/lib/cmake/lld", - "CMAKE_INSTALL_INCLUDEDIR": "include/Halide", - "CMAKE_INSTALL_LIBDIR": "lib/x86_64-linux-gnu", - "Halide_INSTALL_PLUGINDIR": "lib/x86_64-linux-gnu/Halide", - "Halide_INSTALL_HELPERSDIR": "lib/cmake/HalideHelpers", - "CMAKE_STRIP": "${sourceDir}/packaging/ubuntu/extra-strip.sh" + "LLVM_ROOT": "$penv{LLVM_ROOT}" } + } + ], + "buildPresets": [ + { + "name": "debug", + "configurePreset": "debug", + "displayName": "Debug", + "description": "Debug build with no special settings" }, { - "name": "package-ubuntu-static", - "inherits": "package-ubuntu-shared", - "displayName": "Package static Halide for Ubuntu", - "description": "Package static Halide for Ubuntu, using system packages.", - "binaryDir": "static-release", - "cacheVariables": { - "BUILD_SHARED_LIBS": "NO", - "WITH_DOCS": "NO" + "name": "release", + "configurePreset": "release", + "displayName": "Release", + "description": "Release build with no special settings" + }, + { + "name": "linux-x64-asan", + "configurePreset": "linux-x64-asan", + "displayName": "ASAN (Linux x64)", + "description": "Build everything with ASAN enabled" + } + ], + "testPresets": [ + { + "name": "debug", + "configurePreset": "debug", + "displayName": "Debug", + "description": "Test everything with Debug build", + "output": { + "outputOnFailure": true + } + }, + { + "name": "release", + "configurePreset": "release", + "displayName": "Release", + "description": "Test everything with Release build", + "output": { + "outputOnFailure": true + } + }, + { + "name": "linux-x64-asan", + "configurePreset": "linux-x64-asan", + "displayName": "ASAN (Linux x64)", + "description": "Test everything with ASAN enabled", + "environment": { + "ASAN_OPTIONS": "detect_leaks=0:detect_container_overflow=0" + }, + "output": { + "outputOnFailure": true } } ] diff --git a/LICENSE.txt b/LICENSE.txt index a9b9ca6b7d48..13146db88f3b 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -27,15 +27,120 @@ SOFTWARE. apps/bgu is Copyright 2016 Google Inc. and is Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -with the License. You may obtain a copy of the License at +with the License. -http ://www.apache.org/licenses/LICENSE-2.0 +Apache License -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. +Version 2.0, January 2004 + +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and +distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the +copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other +entities that control, are controlled by, or are under common control with +that entity. For the purposes of this definition, "control" means (i) the +power, direct or indirect, to cause the direction or management of such +entity, whether by contract or otherwise, or (ii) ownership of fifty percent +(50%) or more of the outstanding shares, or (iii) beneficial ownership of such +entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising +permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, +including but not limited to software source code, documentation source, and +configuration files. + +"Object" form shall mean any form resulting from mechanical transformation or +translation of a Source form, including but not limited to compiled object +code, generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, +made available under the License, as indicated by a copyright notice that is +included in or attached to the work (an example is provided in the Appendix +below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that +is based on (or derived from) the Work and for which the editorial revisions, +annotations, elaborations, or other modifications represent, as a whole, an +original work of authorship. For the purposes of this License, Derivative +Works shall not include works that remain separable from, or merely link (or +bind by name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original +version of the Work and any modifications or additions to that Work or +Derivative Works thereof, that is intentionally submitted to Licensor for +inclusion in the Work by the copyright owner or by an individual or Legal +Entity authorized to submit on behalf of the copyright owner. For the purposes +of this definition, "submitted" means any form of electronic, verbal, or +written communication sent to the Licensor or its representatives, including +but not limited to communication on electronic mailing lists, source code +control systems, and issue tracking systems that are managed by, or on behalf +of, the Licensor for the purpose of discussing and improving the Work, but +excluding communication that is conspicuously marked or otherwise designated +in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf +of whom a Contribution has been received by Licensor and subsequently +incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + +(a) You must give any other recipients of the Work or Derivative Works a copy +of this License; and + +(b) You must cause any modified files to carry prominent notices stating that +You changed the files; and + +(c) You must retain, in the Source form of any Derivative Works that You +distribute, all copyright, patent, trademark, and attribution notices from the +Source form of the Work, excluding those notices that do not pertain to any +part of the Derivative Works; and + +(d) If the Work includes a "NOTICE" text file as part of its distribution, +then any Derivative Works that You distribute must include a readable copy of +the attribution notices contained within such NOTICE file, excluding those +notices that do not pertain to any part of the Derivative Works, in at least +one of the following places: within a NOTICE text file distributed as part of +the Derivative Works; within the Source form or documentation, if provided +along with the Derivative Works; or, within a display generated by the +Derivative Works, if and wherever such third-party notices normally appear. +The contents of the NOTICE file are for informational purposes only and do not +modify the License. You may add Your own attribution notices within Derivative +Works that You distribute, alongside or as an addendum to the NOTICE text from +the Work, provided that such additional attribution notices cannot be +construed as modifying the License. + +You may add Your own copyright statement to Your modifications and may provide +additional or different license terms and conditions for use, reproduction, or +distribution of Your modifications, or for any such Derivative Works as a +whole, provided Your use, reproduction, and distribution of the Work otherwise +complies with the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS ----- @@ -63,3 +168,49 @@ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +---- + +dependencies/spirv is Copyright (c) 2014-2018 The Khronos Group Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and/or associated documentation files (the "Materials"), +to deal in the Materials without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Materials, and to permit persons to whom the +Materials are furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Materials. + +MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS KHRONOS +STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS SPECIFICATIONS AND +HEADER INFORMATION ARE LOCATED AT https://www.khronos.org/registry/ + +THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM,OUT OF OR IN CONNECTION WITH THE MATERIALS OR THE USE OR OTHER DEALINGS +IN THE MATERIALS. + +---- + +apps/linear_algebra/include/cblas.h is licensed under the BLAS license. + +The reference BLAS is a freely-available software package. It is available from +netlib via anonymous ftp and the World Wide Web. Thus, it can be included in +commercial software packages (and has been). We only ask that proper credit be +given to the authors. + +Like all software, it is copyrighted. It is not trademarked, but we do ask the +following: + +If you modify the source for these routines we ask that you change the name of +the routine and comment the changes made to the original. + +We will gladly answer any questions regarding the software. If a modification is +done, however, it is the responsibility of the person who modified the routine +to provide support. + diff --git a/Makefile b/Makefile index 3d59cdaf5770..b32e249312dc 100644 --- a/Makefile +++ b/Makefile @@ -29,6 +29,11 @@ else endif endif +# We want to build Halide plugins as .so on all posixy systems, including OSX. +# This is called out as a named var to make it clear that the use +# is deliberate, not an accident. +PLUGIN_EXT=so + ifeq ($(UNAME), Darwin) # Anything that we us install_name_tool on needs these linker flags # to ensure there is enough padding for install_name_tool to use @@ -67,8 +72,6 @@ LLVM_CXX_FLAGS = -std=c++17 $(filter-out -O% -g -fomit-frame-pointer -pedantic OPTIMIZE ?= -O3 OPTIMIZE_FOR_BUILD_TIME ?= -O0 -PYTHON ?= python3 - CLANG ?= $(LLVM_BINDIR)/clang CLANG_VERSION = $(shell $(CLANG) --version) @@ -265,6 +268,9 @@ TEST_LD_FLAGS = -L$(BIN_DIR) -lHalide $(COMMON_LD_FLAGS) # In the tests, some of our expectations change depending on the llvm version TEST_CXX_FLAGS += -DLLVM_VERSION=$(LLVM_VERSION_TIMES_10) +# In the tests, default to exporting no symbols that aren't explicitly exported +TEST_CXX_FLAGS += -fvisibility=hidden -fvisibility-inlines-hidden + # gcc 4.8 fires a bogus warning on old versions of png.h ifneq (,$(findstring g++,$(CXX_VERSION))) ifneq (,$(findstring 4.8,$(CXX_VERSION))) @@ -399,6 +405,7 @@ HEXAGON_RUNTIME_LIBS = \ # Keep this list sorted in alphabetical order. SOURCE_FILES = \ + AbstractGenerator.cpp \ AddAtomicMutex.cpp \ AddImageChecks.cpp \ AddParameterChecks.cpp \ @@ -415,6 +422,7 @@ SOURCE_FILES = \ BoundsInference.cpp \ BoundSmallAllocations.cpp \ Buffer.cpp \ + Callable.cpp \ CanonicalizeGPUVars.cpp \ Closure.cpp \ ClampUnsafeAccesses.cpp \ @@ -493,7 +501,6 @@ SOURCE_FILES = \ Lower.cpp \ LowerParallelTasks.cpp \ LowerWarpShuffles.cpp \ - MatlabWrapper.cpp \ Memoization.cpp \ Module.cpp \ ModulusRemainder.cpp \ @@ -530,6 +537,7 @@ SOURCE_FILES = \ Simplify_And.cpp \ Simplify_Call.cpp \ Simplify_Cast.cpp \ + Simplify_Reinterpret.cpp \ Simplify_Div.cpp \ Simplify_EQ.cpp \ Simplify_Exprs.cpp \ @@ -550,6 +558,7 @@ SOURCE_FILES = \ SkipStages.cpp \ SlidingWindow.cpp \ Solve.cpp \ + SpirvIR.cpp \ SplitTuples.cpp \ StmtToHtml.cpp \ StorageFlattening.cpp \ @@ -574,8 +583,10 @@ SOURCE_FILES = \ # The externally-visible header files that go into making Halide.h. # Don't include anything here that includes llvm headers. +# Also *don't* include anything that's only used internally (eg SpirvIR.h). # Keep this list sorted in alphabetical order. HEADER_FILES = \ + AbstractGenerator.h \ AddAtomicMutex.h \ AddImageChecks.h \ AddParameterChecks.h \ @@ -592,6 +603,7 @@ HEADER_FILES = \ BoundsInference.h \ BoundSmallAllocations.h \ Buffer.h \ + Callable.h \ CanonicalizeGPUVars.h \ ClampUnsafeAccesses.h \ Closure.h \ @@ -673,7 +685,6 @@ HEADER_FILES = \ LowerParallelTasks.h \ LowerWarpShuffles.h \ MainPage.h \ - MatlabWrapper.h \ Memoization.h \ Module.h \ ModulusRemainder.h \ @@ -758,6 +769,7 @@ RUNTIME_CPP_COMPONENTS = \ fake_get_symbol \ fake_thread_pool \ float16_t \ + force_include_types \ fuchsia_clock \ fuchsia_host_cpu_count \ fuchsia_yield \ @@ -772,8 +784,6 @@ RUNTIME_CPP_COMPONENTS = \ linux_clock \ linux_host_cpu_count \ linux_yield \ - matlab \ - metadata \ metal \ metal_objc_arm \ metal_objc_x86 \ @@ -978,6 +988,8 @@ $(BIN_DIR)/build_halide_h: $(ROOT_DIR)/tools/build_halide_h.cpp -include $(OBJECTS:.o=.d) -include $(INITIAL_MODULES:.o=.d) +.SECONDARY: + # Compile generic 32- or 64-bit code # (The 'nacl' is a red herring. This is just a generic 32-bit little-endian target.) RUNTIME_TRIPLE_32 = "le32-unknown-nacl-unknown" @@ -1066,7 +1078,7 @@ $(BUILD_DIR)/initmod.%_32_debug.ll: $(SRC_DIR)/runtime/%.cpp $(BUILD_DIR)/clang_ @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) -g -DDEBUG_RUNTIME -O3 $(RUNTIME_CXX_FLAGS) -fpic -m32 -target $(RUNTIME_TRIPLE_32) -DCOMPILING_HALIDE_RUNTIME -DBITS_32 -emit-llvm -S $(SRC_DIR)/runtime/$*.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.$*_32_debug.d -ifneq (,$(findstring $(LLVM_VERSION_TIMES_10), 120 130)) +ifneq (,$(findstring $(LLVM_VERSION_TIMES_10), 130)) # For LLVM14+, we must add elementtype() annotations to some of our LLVM IR; # earlier versions either don't understand that keyword at all, or don't support # the uses we have for it. Rather than forking these sources, for now we'll just @@ -1136,12 +1148,11 @@ clean: rm -rf $(DISTRIB_DIR) rm -rf $(ROOT_DIR)/apps/*/bin -.SECONDARY: - CORRECTNESS_TESTS = $(shell ls $(ROOT_DIR)/test/correctness/*.cpp) $(shell ls $(ROOT_DIR)/test/correctness/*.c) PERFORMANCE_TESTS = $(shell ls $(ROOT_DIR)/test/performance/*.cpp) ERROR_TESTS = $(shell ls $(ROOT_DIR)/test/error/*.cpp) WARNING_TESTS = $(shell ls $(ROOT_DIR)/test/warning/*.cpp) +RUNTIME_TESTS = $(shell ls $(ROOT_DIR)/test/runtime/*.cpp) GENERATOR_EXTERNAL_TESTS := $(shell ls $(ROOT_DIR)/test/generator/*test.cpp) GENERATOR_EXTERNAL_TEST_GENERATOR := $(shell ls $(ROOT_DIR)/test/generator/*_generator.cpp) TUTORIALS = $(filter-out %_generate.cpp, $(shell ls $(ROOT_DIR)/tutorial/*.cpp)) @@ -1151,6 +1162,7 @@ test_correctness: $(CORRECTNESS_TESTS:$(ROOT_DIR)/test/correctness/%.cpp=quiet_c test_performance: $(PERFORMANCE_TESTS:$(ROOT_DIR)/test/performance/%.cpp=performance_%) test_error: $(ERROR_TESTS:$(ROOT_DIR)/test/error/%.cpp=error_%) test_warning: $(WARNING_TESTS:$(ROOT_DIR)/test/warning/%.cpp=warning_%) +test_runtime: $(RUNTIME_TESTS:$(ROOT_DIR)/test/runtime/%.cpp=runtime_%) test_tutorial: $(TUTORIALS:$(ROOT_DIR)/tutorial/%.cpp=tutorial_%) test_valgrind: $(CORRECTNESS_TESTS:$(ROOT_DIR)/test/correctness/%.cpp=valgrind_%) test_avx512: $(CORRECTNESS_TESTS:$(ROOT_DIR)/test/correctness/%.cpp=avx512_%) @@ -1201,12 +1213,6 @@ GENERATOR_AOTCPP_TESTS := $(filter-out generator_aotcpp_msan,$(GENERATOR_AOTCPP_ # https://github.com/halide/Halide/issues/2075 GENERATOR_AOTCPP_TESTS := $(filter-out generator_aotcpp_memory_profiler_mandelbrot,$(GENERATOR_AOTCPP_TESTS)) -# https://github.com/halide/Halide/issues/2082 -GENERATOR_AOTCPP_TESTS := $(filter-out generator_aotcpp_matlab,$(GENERATOR_AOTCPP_TESTS)) - -# https://github.com/halide/Halide/issues/2093 -GENERATOR_AOTCPP_TESTS := $(filter-out generator_aotcpp_async_parallel,$(GENERATOR_AOTCPP_TESTS)) - # https://github.com/halide/Halide/issues/4916 GENERATOR_AOTCPP_TESTS := $(filter-out generator_aotcpp_stubtest,$(GENERATOR_AOTCPP_TESTS)) GENERATOR_AOTCPP_TESTS := $(filter-out generator_aotcpp_stubuser,$(GENERATOR_AOTCPP_TESTS)) @@ -1222,7 +1228,6 @@ GENERATOR_BUILD_RUNGEN_TESTS = $(GENERATOR_EXTERNAL_TEST_GENERATOR:$(ROOT_DIR)/t GENERATOR_BUILD_RUNGEN_TESTS := $(filter-out $(FILTERS_DIR)/async_parallel.rungen,$(GENERATOR_BUILD_RUNGEN_TESTS)) GENERATOR_BUILD_RUNGEN_TESTS := $(filter-out $(FILTERS_DIR)/cxx_mangling_define_extern.rungen,$(GENERATOR_BUILD_RUNGEN_TESTS)) GENERATOR_BUILD_RUNGEN_TESTS := $(filter-out $(FILTERS_DIR)/define_extern_opencl.rungen,$(GENERATOR_BUILD_RUNGEN_TESTS)) -GENERATOR_BUILD_RUNGEN_TESTS := $(filter-out $(FILTERS_DIR)/matlab.rungen,$(GENERATOR_BUILD_RUNGEN_TESTS)) GENERATOR_BUILD_RUNGEN_TESTS := $(filter-out $(FILTERS_DIR)/msan.rungen,$(GENERATOR_BUILD_RUNGEN_TESTS)) GENERATOR_BUILD_RUNGEN_TESTS := $(filter-out $(FILTERS_DIR)/sanitizercoverage.rungen,$(GENERATOR_BUILD_RUNGEN_TESTS)) GENERATOR_BUILD_RUNGEN_TESTS := $(filter-out $(FILTERS_DIR)/multitarget.rungen,$(GENERATOR_BUILD_RUNGEN_TESTS)) @@ -1244,7 +1249,7 @@ test_generator: $(GENERATOR_AOT_TESTS) $(GENERATOR_AOTCPP_TESTS) $(GENERATOR_JIT $(FILTERS_DIR)/rungen_test $(FILTERS_DIR)/registration_test -ALL_TESTS = test_internal test_correctness test_error test_tutorial test_warning test_generator +ALL_TESTS = test_internal test_correctness test_error test_tutorial test_warning test_runtime test_generator # These targets perform timings of each test. For most tests this includes Halide JIT compile times, and run times. # For generator tests they time the compile time only. The times are recorded in CSV files. @@ -1265,6 +1270,7 @@ build_tests: $(CORRECTNESS_TESTS:$(ROOT_DIR)/test/correctness/%.cpp=$(BIN_DIR)/c $(PERFORMANCE_TESTS:$(ROOT_DIR)/test/performance/%.cpp=$(BIN_DIR)/performance_%) \ $(ERROR_TESTS:$(ROOT_DIR)/test/error/%.cpp=$(BIN_DIR)/error_%) \ $(WARNING_TESTS:$(ROOT_DIR)/test/warning/%.cpp=$(BIN_DIR)/warning_%) \ + $(RUNTIME_TESTS:$(ROOT_DIR)/test/runtime/%.cpp=$(BIN_DIR)/runtime_%) \ $(GENERATOR_EXTERNAL_TESTS:$(ROOT_DIR)/test/generator/%_aottest.cpp=$(BIN_DIR)/$(TARGET)/generator_aot_%) \ $(GENERATOR_EXTERNAL_TESTS:$(ROOT_DIR)/test/generator/%_jittest.cpp=$(BIN_DIR)/generator_jit_%) \ $(AUTO_SCHEDULE_TESTS:$(ROOT_DIR)/test/auto_schedule/%.cpp=$(BIN_DIR)/auto_schedule_%) @@ -1280,6 +1286,11 @@ clean_generator: time_compilation_tests: time_compilation_correctness time_compilation_performance time_compilation_generator +# These are just aliases to the autoscheduler plugins to make Generator rules & deps a little terser +BIN_ADAMS2019=$(BIN_DIR)/libautoschedule_adams2019.$(PLUGIN_EXT) +BIN_LI2018=$(BIN_DIR)/libautoschedule_li2018.$(PLUGIN_EXT) +BIN_MULLAPUDI2016=$(BIN_DIR)/libautoschedule_mullapudi2016.$(PLUGIN_EXT) + $(BUILD_DIR)/GenGen.o: $(ROOT_DIR)/tools/GenGen.cpp $(INCLUDE_DIR)/Halide.h @mkdir -p $(@D) $(CXX) -c $< $(TEST_CXX_FLAGS) -I$(INCLUDE_DIR) -o $@ @@ -1332,6 +1343,11 @@ $(BIN_DIR)/error_%: $(ROOT_DIR)/test/error/%.cpp $(BIN_DIR)/libHalide.$(SHARED_E $(BIN_DIR)/warning_%: $(ROOT_DIR)/test/warning/%.cpp $(BIN_DIR)/libHalide.$(SHARED_EXT) $(INCLUDE_DIR)/Halide.h $(CXX) $(TEST_CXX_FLAGS) $(OPTIMIZE_FOR_BUILD_TIME) $< -I$(INCLUDE_DIR) $(TEST_LD_FLAGS) -o $@ +# Runtime tests that test internals +RUNTIME_TESTS_CXXFLAGS = -fno-rtti -fno-exceptions -fno-threadsafe-statics -Wno-builtin-declaration-mismatch -DCOMPILING_HALIDE_RUNTIME -DCOMPILING_HALIDE_RUNTIME_TESTS +$(BIN_DIR)/runtime_%: $(ROOT_DIR)/test/runtime/%.cpp $(ROOT_DIR)/test/runtime/common.h + $(CXX) $(TEST_CXX_FLAGS) $(RUNTIME_TESTS_CXXFLAGS) -I$(ROOT_DIR)/test/runtime -I$(ROOT_DIR)/src/runtime $(OPTIMIZE_FOR_BUILD_TIME) $< $(COMMON_LD_FLAGS) -o $@ + # Auto schedule tests that link against libHalide $(BIN_DIR)/auto_schedule_%: $(ROOT_DIR)/test/auto_schedule/%.cpp $(BIN_DIR)/libHalide.$(SHARED_EXT) $(INCLUDE_DIR)/Halide.h $(CXX) $(TEST_CXX_FLAGS) $(OPTIMIZE_FOR_BUILD_TIME) $< -I$(INCLUDE_DIR) $(TEST_LD_FLAGS) -o $@ @@ -1443,6 +1459,18 @@ $(FILTERS_DIR)/alias_with_offset_42.a: $(BIN_DIR)/alias.generator @mkdir -p $(@D) $(CURDIR)/$< -g alias_with_offset_42 -f alias_with_offset_42 $(GEN_AOT_OUTPUTS) -o $(CURDIR)/$(FILTERS_DIR) target=$(TARGET)-no_runtime +$(FILTERS_DIR)/alias_Adams2019.a: $(BIN_DIR)/alias.generator autoschedulers + @mkdir -p $(@D) + $(CURDIR)/$< -g alias_Adams2019 -f alias_Adams2019 $(GEN_AOT_OUTPUTS) -o $(CURDIR)/$(FILTERS_DIR) target=$(TARGET)-no_runtime -p $(BIN_ADAMS2019) + +$(FILTERS_DIR)/alias_Li2018.a: $(BIN_DIR)/alias.generator autoschedulers + @mkdir -p $(@D) + $(CURDIR)/$< -g alias_Li2018 -f alias_Li2018 $(GEN_AOT_OUTPUTS) -o $(CURDIR)/$(FILTERS_DIR) target=$(TARGET)-no_runtime -p $(BIN_LI2018) + +$(FILTERS_DIR)/alias_Mullapudi2016.a: $(BIN_DIR)/alias.generator autoschedulers + @mkdir -p $(@D) + $(CURDIR)/$< -g alias_Mullapudi2016 -f alias_Mullapudi2016 $(GEN_AOT_OUTPUTS) -o $(CURDIR)/$(FILTERS_DIR) target=$(TARGET)-no_runtime -p $(BIN_MULLAPUDI2016) + METADATA_TESTER_GENERATOR_ARGS=\ input.type=uint8 input.dim=3 \ dim_only_input_buffer.type=uint8 \ @@ -1512,11 +1540,6 @@ $(FILTERS_DIR)/user_context_insanity.a: $(BIN_DIR)/user_context_insanity.generat @mkdir -p $(@D) $(CURDIR)/$< -g user_context_insanity $(GEN_AOT_OUTPUTS) -o $(CURDIR)/$(FILTERS_DIR) target=$(TARGET)-no_runtime-user_context -# matlab needs to be generated with matlab in TARGET -$(FILTERS_DIR)/matlab.a: $(BIN_DIR)/matlab.generator - @mkdir -p $(@D) - $(CURDIR)/$< -g matlab $(GEN_AOT_OUTPUTS) -o $(CURDIR)/$(FILTERS_DIR) target=$(TARGET)-no_runtime-matlab - # Some .generators have additional dependencies (usually due to define_extern usage). # These typically require two extra dependencies: # (1) Ensuring the extra _generator.cpp is built into the .generator. @@ -1555,6 +1578,10 @@ $(FILTERS_DIR)/stubtest.a: $(BIN_DIR)/stubtest.generator @mkdir -p $(@D) $(CURDIR)/$< -g stubtest -f stubtest $(GEN_AOT_OUTPUTS) -o $(CURDIR)/$(FILTERS_DIR) target=$(TARGET)-no_runtime $(STUBTEST_GENERATOR_ARGS) +$(FILTERS_DIR)/stubuser_auto.a: $(BIN_DIR)/stubuser.generator $(BIN_MULLAPUDI2016) + @mkdir -p $(@D) + $(CURDIR)/$< -g stubuser $(GEN_AOT_OUTPUTS) -o $(CURDIR)/$(FILTERS_DIR) -f stubuser_auto target=$(TARGET)-no_runtime autoscheduler=Mullapudi2016 -p $(BIN_MULLAPUDI2016) + $(FILTERS_DIR)/external_code.a: $(BIN_DIR)/external_code.generator @mkdir -p $(@D) $(CURDIR)/$< -g external_code -e static_library,c_header,registration -o $(CURDIR)/$(FILTERS_DIR) target=$(TARGET)-no_runtime external_code_is_bitcode=true @@ -1563,15 +1590,9 @@ $(FILTERS_DIR)/external_code.halide_generated.cpp: $(BIN_DIR)/external_code.gene @mkdir -p $(@D) $(CURDIR)/$< -g external_code -e c_source -o $(CURDIR)/$(FILTERS_DIR) target=$(TARGET)-no_runtime external_code_is_bitcode=false -$(FILTERS_DIR)/autograd_grad.a: $(BIN_DIR)/autograd.generator $(DISTRIB_DIR)/lib/libautoschedule_mullapudi2016.$(SHARED_EXT) +$(FILTERS_DIR)/autograd_grad.a: $(BIN_DIR)/autograd.generator $(BIN_MULLAPUDI2016) @mkdir -p $(@D) - # FIXME: The autoscheduler looks for libHalide in the same - # directory, which is normally a distro. But the generator - # tests use bin/libHalide.so instead of a distro. For now, - # just copy the autoscheduler to a place where it won't - # confuse the linker. - cp $(DISTRIB_DIR)/lib/libautoschedule_mullapudi2016.$(SHARED_EXT) $(BIN_DIR) - $(CURDIR)/$< -g autograd $(GEN_AOT_OUTPUTS) -o $(CURDIR)/$(FILTERS_DIR) -f autograd_grad target=$(TARGET)-no_runtime auto_schedule=true -d 1 -p $(BIN_DIR)/libautoschedule_mullapudi2016.$(SHARED_EXT) -s Mullapudi2016 + $(CURDIR)/$< -g autograd $(GEN_AOT_OUTPUTS) -o $(CURDIR)/$(FILTERS_DIR) -f autograd_grad target=$(TARGET)-no_runtime autoscheduler=Mullapudi2016 -d 1 -p $(BIN_MULLAPUDI2016) # Usually, it's considered best practice to have one Generator per # .cpp file, with the generator-name and filename matching; @@ -1618,12 +1639,13 @@ $(BIN_DIR)/$(TARGET)/generator_aot_sanitizercoverage: $(ROOT_DIR)/test/generator @mkdir -p $(@D) $(CXX) $(GEN_AOT_CXX_FLAGS) $(filter-out %.h,$^) $(GEN_AOT_INCLUDES) $(GEN_AOT_LD_FLAGS) -o $@ + # alias has additional deps to link in -$(BIN_DIR)/$(TARGET)/generator_aot_alias: $(ROOT_DIR)/test/generator/alias_aottest.cpp $(FILTERS_DIR)/alias.a $(FILTERS_DIR)/alias_with_offset_42.a $(RUNTIME_EXPORTED_INCLUDES) $(BIN_DIR)/$(TARGET)/runtime.a +$(BIN_DIR)/$(TARGET)/generator_aot_alias: $(ROOT_DIR)/test/generator/alias_aottest.cpp $(FILTERS_DIR)/alias.a $(FILTERS_DIR)/alias_with_offset_42.a $(FILTERS_DIR)/alias_Adams2019.a $(FILTERS_DIR)/alias_Li2018.a $(FILTERS_DIR)/alias_Mullapudi2016.a $(RUNTIME_EXPORTED_INCLUDES) $(BIN_DIR)/$(TARGET)/runtime.a @mkdir -p $(@D) $(CXX) $(GEN_AOT_CXX_FLAGS) $(filter %.cpp %.o %.a,$^) $(GEN_AOT_INCLUDES) $(GEN_AOT_LD_FLAGS) -o $@ -$(BIN_DIR)/$(TARGET)/generator_aotcpp_alias: $(ROOT_DIR)/test/generator/alias_aottest.cpp $(FILTERS_DIR)/alias.halide_generated.cpp $(FILTERS_DIR)/alias_with_offset_42.halide_generated.cpp $(RUNTIME_EXPORTED_INCLUDES) $(BIN_DIR)/$(TARGET)/runtime.a +$(BIN_DIR)/$(TARGET)/generator_aotcpp_alias: $(ROOT_DIR)/test/generator/alias_aottest.cpp $(FILTERS_DIR)/alias.halide_generated.cpp $(FILTERS_DIR)/alias_with_offset_42.halide_generated.cpp $(FILTERS_DIR)/alias_Adams2019.halide_generated.cpp $(FILTERS_DIR)/alias_Li2018.halide_generated.cpp $(FILTERS_DIR)/alias_Mullapudi2016.halide_generated.cpp $(RUNTIME_EXPORTED_INCLUDES) $(BIN_DIR)/$(TARGET)/runtime.a @mkdir -p $(@D) $(CXX) $(GEN_AOT_CXX_FLAGS) $(filter %.cpp %.o %.a,$^) $(GEN_AOT_INCLUDES) $(GEN_AOT_LD_FLAGS) -o $@ @@ -1645,15 +1667,6 @@ $(BIN_DIR)/$(TARGET)/generator_aotcpp_nested_externs: $(ROOT_DIR)/test/generator @mkdir -p $(@D) $(CXX) $(GEN_AOT_CXX_FLAGS) $(filter %.cpp %.o %.a,$^) $(GEN_AOT_INCLUDES) $(GEN_AOT_LD_FLAGS) -o $@ -# The matlab tests needs "-matlab" in the runtime -$(BIN_DIR)/$(TARGET)/generator_aot_matlab: $(ROOT_DIR)/test/generator/matlab_aottest.cpp $(FILTERS_DIR)/matlab.a $(FILTERS_DIR)/matlab.h $(RUNTIME_EXPORTED_INCLUDES) $(BIN_DIR)/$(TARGET)-matlab/runtime.a - @mkdir -p $(@D) - $(CXX) $(GEN_AOT_CXX_FLAGS) $(filter %.cpp %.o %.a,$^) $(GEN_AOT_INCLUDES) $(GEN_AOT_LD_FLAGS) $(TEST_LD_FLAGS) -o $@ - -$(BIN_DIR)/$(TARGET)/generator_aotcpp_matlab: $(ROOT_DIR)/test/generator/matlab_aottest.cpp $(FILTERS_DIR)/matlab.halide_generated.cpp $(FILTERS_DIR)/matlab.h $(RUNTIME_EXPORTED_INCLUDES) $(BIN_DIR)/$(TARGET)-matlab/runtime.a - @mkdir -p $(@D) - $(CXX) $(GEN_AOT_CXX_FLAGS) $(filter %.cpp %.o %.a,$^) $(GEN_AOT_INCLUDES) $(GEN_AOT_LD_FLAGS) $(TEST_LD_FLAGS) -o $@ - # The gpu object lifetime test needs the debug runtime $(BIN_DIR)/$(TARGET)/generator_aot_gpu_object_lifetime: $(ROOT_DIR)/test/generator/gpu_object_lifetime_aottest.cpp $(FILTERS_DIR)/gpu_object_lifetime.a $(FILTERS_DIR)/gpu_object_lifetime.h $(RUNTIME_EXPORTED_INCLUDES) $(BIN_DIR)/$(TARGET)-debug/runtime.a @mkdir -p $(@D) @@ -1682,6 +1695,11 @@ $(BIN_DIR)/generator_jit_%: $(ROOT_DIR)/test/generator/%_jittest.cpp $(BIN_DIR)/ @mkdir -p $(@D) $(CXX) -g $(TEST_CXX_FLAGS) $(filter %.cpp %.o %.a,$^) -I$(INCLUDE_DIR) -I$(FILTERS_DIR) -I $(ROOT_DIR)/apps/support $(TEST_LD_FLAGS) -o $@ +# stubuser is run with autoscheduling too +$(BIN_DIR)/$(TARGET)/generator_aot_stubuser: $(ROOT_DIR)/test/generator/stubuser_aottest.cpp $(FILTERS_DIR)/stubuser.a $(FILTERS_DIR)/stubuser.h $(FILTERS_DIR)/stubuser_auto.a $(FILTERS_DIR)/stubuser_auto.h $(RUNTIME_EXPORTED_INCLUDES) $(BIN_DIR)/$(TARGET)/runtime.a + @mkdir -p $(@D) + $(CXX) $(GEN_AOT_CXX_FLAGS) $(filter %.cpp %.o %.a,$^) $(GEN_AOT_INCLUDES) $(GEN_AOT_LD_FLAGS) -o $@ + # generator_aot_multitarget is run multiple times, with different env vars. generator_aot_multitarget: $(BIN_DIR)/$(TARGET)/generator_aot_multitarget @mkdir -p $(@D) @@ -1852,21 +1870,23 @@ $(BIN_DIR)/tutorial_lesson_21_auto_scheduler_generate: $(ROOT_DIR)/tutorial/less $(CXX) $(TUTORIAL_CXX_FLAGS) $(IMAGE_IO_CXX_FLAGS) $(OPTIMIZE_FOR_BUILD_TIME) $< $(BUILD_DIR)/GenGen.o \ -I$(INCLUDE_DIR) $(TEST_LD_FLAGS) $(IMAGE_IO_LIBS) -o $@ -# The values in MachineParams are: +# The values are: # - the maximum level of parallelism available, # - the size of the last-level cache (in bytes), # - the ratio between the cost of a miss at the last level cache and the cost # of arithmetic on the target architecture # ...in that order. -LESSON_21_MACHINE_PARAMS = 32,16777216,40 +LESSON_21_AUTOSCHEDULER_PARAMS=\ + autoscheduler=Mullapudi2016 \ + autoscheduler.parallelism=32 \ + autoscheduler.last_level_cache_size=16777216 \ + autoscheduler.balance=40 -$(BIN_DIR)/tutorial_lesson_21_auto_scheduler_run: $(ROOT_DIR)/tutorial/lesson_21_auto_scheduler_run.cpp $(BIN_DIR)/tutorial_lesson_21_auto_scheduler_generate $(DISTRIB_DIR)/lib/libautoschedule_mullapudi2016.$(SHARED_EXT) +$(BIN_DIR)/tutorial_lesson_21_auto_scheduler_run: $(ROOT_DIR)/tutorial/lesson_21_auto_scheduler_run.cpp $(BIN_DIR)/tutorial_lesson_21_auto_scheduler_generate $(BIN_MULLAPUDI2016) @-mkdir -p $(TMP_DIR) # Run the generator - $(BIN_DIR)/tutorial_lesson_21_auto_scheduler_generate -g auto_schedule_gen -o $(TMP_DIR) -e static_library,c_header,schedule -f auto_schedule_false target=host auto_schedule=false - # FIXME: The relative path of the autoscheduler and libHalide must be preserved on OS X, or it tries to load the wrong libHalide.dylib - cp $(DISTRIB_DIR)/lib/libautoschedule_mullapudi2016.$(SHARED_EXT) $(BIN_DIR) - $(BIN_DIR)/tutorial_lesson_21_auto_scheduler_generate -g auto_schedule_gen -o $(TMP_DIR) -e static_library,c_header,schedule -f auto_schedule_true target=host-no_runtime auto_schedule=true machine_params=$(LESSON_21_MACHINE_PARAMS) -p $(BIN_DIR)/libautoschedule_mullapudi2016.$(SHARED_EXT) -s Mullapudi2016 + $(BIN_DIR)/tutorial_lesson_21_auto_scheduler_generate -g auto_schedule_gen -o $(TMP_DIR) -e static_library,c_header,schedule -f auto_schedule_false target=host + $(BIN_DIR)/tutorial_lesson_21_auto_scheduler_generate -g auto_schedule_gen -o $(TMP_DIR) -e static_library,c_header,schedule -f auto_schedule_true target=host-no_runtime $(LESSON_21_AUTOSCHEDULER_PARAMS) -p $(BIN_MULLAPUDI2016) # Compile the runner $(CXX) $(TUTORIAL_CXX_FLAGS) $(IMAGE_IO_CXX_FLAGS) $(OPTIMIZE_FOR_BUILD_TIME) $< \ -I$(INCLUDE_DIR) -L$(BIN_DIR) -I $(TMP_DIR) $(TMP_DIR)/auto_schedule_*.a \ @@ -1925,6 +1945,11 @@ warning_%: $(BIN_DIR)/warning_% cd $(TMP_DIR) ; $(CURDIR)/$< 2>&1 | egrep --q "^Warning" @-echo +runtime_%: $(BIN_DIR)/runtime_% + @-mkdir -p $(TMP_DIR) + cd $(TMP_DIR) ; $(CURDIR)/$< + @-echo + generator_jit_%: $(BIN_DIR)/generator_jit_% @-mkdir -p $(TMP_DIR) cd $(TMP_DIR) ; $(CURDIR)/$< @@ -1953,9 +1978,9 @@ test_mullapudi2016: $(AUTO_SCHEDULE_TESTS:$(ROOT_DIR)/test/auto_schedule/%.cpp=a # These tests were written for the Mullapudi2016 autoscheduler. # TODO: either make them work with all autoschedulers or move them under src/autoschedulers/mullapudi2016 -auto_schedule_%: $(BIN_DIR)/auto_schedule_% $(BIN_DIR)/libautoschedule_mullapudi2016.$(SHARED_EXT) +auto_schedule_%: $(BIN_DIR)/auto_schedule_% $(BIN_MULLAPUDI2016) @-mkdir -p $(TMP_DIR) - cd $(TMP_DIR) ; $(CURDIR)/$< $(realpath $(BIN_DIR))/libautoschedule_mullapudi2016.$(SHARED_EXT) + cd $(TMP_DIR) ; $(CURDIR)/$< $(realpath $(BIN_MULLAPUDI2016)) @-echo # The other autoschedulers contain their own tests @@ -1967,10 +1992,9 @@ test_adams2019: distrib $(MAKE) -f $(SRC_DIR)/autoschedulers/adams2019/Makefile test \ HALIDE_DISTRIB_PATH=$(CURDIR)/$(DISTRIB_DIR) -test_li2018: distrib build_python_bindings +test_li2018: distrib $(MAKE) -f $(SRC_DIR)/autoschedulers/li2018/Makefile test \ - HALIDE_DISTRIB_PATH=$(CURDIR)/$(DISTRIB_DIR) \ - HALIDE_PYTHON_BINDINGS_PATH=$(CURDIR)/$(BIN_DIR)/python3_bindings + HALIDE_DISTRIB_PATH=$(CURDIR)/$(DISTRIB_DIR) time_compilation_test_%: $(BIN_DIR)/test_% $(TIME_COMPILATION) compile_times_correctness.csv make -f $(THIS_MAKEFILE) $(@:time_compilation_test_%=test_%) @@ -1982,7 +2006,6 @@ time_compilation_generator_%: $(BIN_DIR)/%.generator $(TIME_COMPILATION) compile_times_generator.csv make -f $(THIS_MAKEFILE) $(@:time_compilation_generator_%=$(FILTERS_DIR)/%.a) TEST_APPS=\ - HelloMatlab \ bilateral_grid \ bgu \ blur \ @@ -2006,20 +2029,18 @@ TEST_APPS=\ TEST_APPS_DEPS=$(TEST_APPS:%=%_test_app) BUILD_APPS_DEPS=$(TEST_APPS:%=%_build_app) -$(BUILD_APPS_DEPS): distrib build_python_bindings +$(BUILD_APPS_DEPS): distrib @echo Building app $(@:%_build_app=%) for ${HL_TARGET}... @$(MAKE) -C $(ROOT_DIR)/apps/$(@:%_build_app=%) build \ HALIDE_DISTRIB_PATH=$(CURDIR)/$(DISTRIB_DIR) \ - HALIDE_PYTHON_BINDINGS_PATH=$(CURDIR)/$(BIN_DIR)/python3_bindings \ BIN_DIR=$(CURDIR)/$(BIN_DIR)/apps/$(@:%_build_app=%)/bin \ HL_TARGET=$(HL_TARGET) \ || exit 1 ; \ -$(TEST_APPS_DEPS): distrib build_python_bindings +$(TEST_APPS_DEPS): distrib @echo Testing app $(@:%_test_app=%) for ${HL_TARGET}... @$(MAKE) -C $(ROOT_DIR)/apps/$(@:%_test_app=%) test \ HALIDE_DISTRIB_PATH=$(CURDIR)/$(DISTRIB_DIR) \ - HALIDE_PYTHON_BINDINGS_PATH=$(CURDIR)/$(BIN_DIR)/python3_bindings \ BIN_DIR=$(CURDIR)/$(BIN_DIR)/apps/$(@:%_test_app=%)/bin \ HL_TARGET=$(HL_TARGET) \ || exit 1 ; \ @@ -2034,7 +2055,6 @@ build_hannk: distrib @echo Building apps/hannk for ${HL_TARGET}... @$(MAKE) -C $(ROOT_DIR)/apps/hannk build \ HALIDE_DISTRIB_PATH=$(CURDIR)/$(DISTRIB_DIR) \ - HALIDE_PYTHON_BINDINGS_PATH=$(CURDIR)/$(BIN_DIR)/python3_bindings \ BIN_DIR=$(CURDIR)/$(BIN_DIR)/apps/hannk/bin \ HL_TARGET=$(HL_TARGET) \ || exit 1 ; \ @@ -2043,7 +2063,6 @@ test_hannk: build_hannk @echo Testing apps/hannk for ${HL_TARGET}... @$(MAKE) -C $(ROOT_DIR)/apps/hannk test \ HALIDE_DISTRIB_PATH=$(CURDIR)/$(DISTRIB_DIR) \ - HALIDE_PYTHON_BINDINGS_PATH=$(CURDIR)/$(BIN_DIR)/python3_bindings \ BIN_DIR=$(CURDIR)/$(BIN_DIR)/apps/hannk/bin \ HL_TARGET=$(HL_TARGET) \ || exit 1 ; \ @@ -2056,12 +2075,11 @@ BENCHMARK_APPS=\ nl_means \ stencil_chain -$(BENCHMARK_APPS): distrib build_python_bindings +$(BENCHMARK_APPS): distrib @echo Building $@ for ${HL_TARGET}... @$(MAKE) -C $(ROOT_DIR)/apps/$@ \ $(CURDIR)/$(BIN_DIR)/apps/$@/bin/$(HL_TARGET)/$@.rungen \ HALIDE_DISTRIB_PATH=$(CURDIR)/$(DISTRIB_DIR) \ - HALIDE_PYTHON_BINDINGS_PATH=$(CURDIR)/$(BIN_DIR)/python3_bindings \ BIN_DIR=$(CURDIR)/$(BIN_DIR)/apps/$@/bin \ HL_TARGET=$(HL_TARGET) \ > /dev/null \ @@ -2075,34 +2093,11 @@ benchmark_apps: $(BENCHMARK_APPS) make -C $(ROOT_DIR)/apps/$${APP} \ $${APP}.benchmark \ HALIDE_DISTRIB_PATH=$(CURDIR)/$(DISTRIB_DIR) \ - HALIDE_PYTHON_BINDINGS_PATH=$(CURDIR)/$(BIN_DIR)/python3_bindings \ BIN_DIR=$(CURDIR)/$(BIN_DIR)/apps/$${APP}/bin \ HL_TARGET=$(HL_TARGET) \ || exit 1 ; \ done -# TODO(srj): the python bindings need to be put into the distrib folders; -# this is a hopefully-temporary workaround (https://github.com/halide/Halide/issues/4368) -.PHONY: build_python_bindings -build_python_bindings: distrib $(BIN_DIR)/host/runtime.a - $(MAKE) -C $(ROOT_DIR)/python_bindings \ - -f $(ROOT_DIR)/python_bindings/Makefile \ - build_python_bindings \ - HALIDE_DISTRIB_PATH=$(CURDIR)/$(DISTRIB_DIR) \ - BIN=$(CURDIR)/$(BIN_DIR)/python3_bindings \ - PYTHON=$(PYTHON) \ - OPTIMIZE="$(OPTIMIZE)" - -.PHONY: test_python -test_python: distrib $(BIN_DIR)/host/runtime.a build_python_bindings - $(MAKE) -C $(ROOT_DIR)/python_bindings \ - -f $(ROOT_DIR)/python_bindings/Makefile \ - test \ - HALIDE_DISTRIB_PATH=$(CURDIR)/$(DISTRIB_DIR) \ - BIN=$(CURDIR)/$(BIN_DIR)/python3_bindings \ - PYTHON=$(PYTHON) \ - OPTIMIZE="$(OPTIMIZE)" - # It's just for compiling the runtime, so earlier clangs *might* work, # but best to peg it to the minimum llvm version. ifneq (,$(findstring clang version 3.7,$(CLANG_VERSION))) @@ -2169,6 +2164,10 @@ ifneq (,$(findstring clang version 15.0,$(CLANG_VERSION))) CLANG_OK=yes endif +ifneq (,$(findstring clang version 16.0,$(CLANG_VERSION))) +CLANG_OK=yes +endif + ifneq (,$(findstring Apple LLVM version 5.0,$(CLANG_VERSION))) CLANG_OK=yes endif @@ -2189,7 +2188,7 @@ $(BUILD_DIR)/clang_ok: @exit 1 endif -ifneq (,$(findstring $(LLVM_VERSION_TIMES_10), 120 130 140, 150)) +ifneq (,$(findstring $(LLVM_VERSION_TIMES_10), 130 140 150 160)) LLVM_OK=yes endif @@ -2239,7 +2238,6 @@ install: $(LIB_DIR)/libHalide.a $(BIN_DIR)/libHalide.$(SHARED_EXT) $(INCLUDE_DIR cp $(ROOT_DIR)/tutorial/*.cpp $(PREFIX)/share/halide/tutorial cp $(ROOT_DIR)/tutorial/*.h $(PREFIX)/share/halide/tutorial cp $(ROOT_DIR)/tutorial/*.sh $(PREFIX)/share/halide/tutorial - cp $(ROOT_DIR)/tools/mex_halide.m $(PREFIX)/share/halide/tools cp $(ROOT_DIR)/tools/GenGen.cpp $(PREFIX)/share/halide/tools cp $(ROOT_DIR)/tools/RunGen.h $(PREFIX)/share/halide/tools cp $(ROOT_DIR)/tools/RunGenMain.cpp $(PREFIX)/share/halide/tools @@ -2317,7 +2315,6 @@ $(DISTRIB_DIR)/lib/libHalide.$(SHARED_EXT): \ cp $(ROOT_DIR)/tutorial/*.cpp $(DISTRIB_DIR)/tutorial cp $(ROOT_DIR)/tutorial/*.h $(DISTRIB_DIR)/tutorial cp $(ROOT_DIR)/tutorial/*.sh $(DISTRIB_DIR)/tutorial - cp $(ROOT_DIR)/tools/mex_halide.m $(DISTRIB_DIR)/tools cp $(ROOT_DIR)/tools/GenGen.cpp $(DISTRIB_DIR)/tools cp $(ROOT_DIR)/tools/RunGen.h $(DISTRIB_DIR)/tools cp $(ROOT_DIR)/tools/RunGenMain.cpp $(DISTRIB_DIR)/tools @@ -2333,17 +2330,25 @@ ifeq ($(UNAME), Darwin) install_name_tool -id @rpath/libHalide.$(SHARED_EXT) $(DISTRIB_DIR)/lib/libHalide.$(SHARED_EXT) endif -$(DISTRIB_DIR)/lib/libautoschedule_%.$(SHARED_EXT): $(DISTRIB_DIR)/lib/libHalide.$(SHARED_EXT) - $(MAKE) -f $(SRC_DIR)/autoschedulers/$*/Makefile bin/libautoschedule_$*.$(SHARED_EXT) HALIDE_DISTRIB_PATH=$(CURDIR)/$(DISTRIB_DIR) - cp $(BIN_DIR)/libautoschedule_$*.$(SHARED_EXT) $(DISTRIB_DIR)/lib +$(BIN_DIR)/libautoschedule_%.$(PLUGIN_EXT): $(DISTRIB_DIR)/lib/libHalide.$(SHARED_EXT) + $(MAKE) -f $(SRC_DIR)/autoschedulers/$*/Makefile $@ HALIDE_DISTRIB_PATH=$(CURDIR)/$(DISTRIB_DIR) +ifeq ($(UNAME), Darwin) + install_name_tool -id @rpath/$(@F) $(CURDIR)/$@ +endif + + +$(DISTRIB_DIR)/lib/libautoschedule_%.$(PLUGIN_EXT): $(BIN_DIR)/libautoschedule_%.$(PLUGIN_EXT) + @mkdir -p $(@D) + cp $< $(DISTRIB_DIR)/lib ifeq ($(UNAME), Darwin) install_name_tool -id @rpath/$(@F) $(CURDIR)/$@ endif # Adams2019 also includes autotuning tools -$(DISTRIB_DIR)/lib/libautoschedule_adams2019.$(SHARED_EXT): $(DISTRIB_DIR)/lib/libHalide.$(SHARED_EXT) - $(MAKE) -f $(SRC_DIR)/autoschedulers/adams2019/Makefile bin/libautoschedule_adams2019.$(SHARED_EXT) HALIDE_DISTRIB_PATH=$(CURDIR)/$(DISTRIB_DIR) bin/retrain_cost_model bin/featurization_to_sample bin/get_host_target - cp $(BIN_DIR)/libautoschedule_adams2019.$(SHARED_EXT) $(DISTRIB_DIR)/lib/ +$(DISTRIB_DIR)/lib/libautoschedule_adams2019.$(PLUGIN_EXT): $(BIN_DIR)/libautoschedule_adams2019.$(PLUGIN_EXT) + @mkdir -p $(@D) + $(MAKE) -f $(SRC_DIR)/autoschedulers/adams2019/Makefile $(BIN_DIR)/retrain_cost_model $(BIN_DIR)/featurization_to_sample $(BIN_DIR)/get_host_target HALIDE_DISTRIB_PATH=$(CURDIR)/$(DISTRIB_DIR) + cp $< $(DISTRIB_DIR)/lib/ for TOOL in retrain_cost_model featurization_to_sample get_host_target; do \ cp $(BIN_DIR)/$${TOOL} $(DISTRIB_DIR)/bin/; \ done diff --git a/README.md b/README.md index df64867f4a36..7ae0b66ecdff 100644 --- a/README.md +++ b/README.md @@ -31,10 +31,12 @@ If you've acquired a full source distribution and want to build Halide, see the ## Binary tarballs -The latest version of Halide is **Halide 13.0.0**. We provide binary releases -for many popular platforms and architectures, including 32/64-bit x86 Windows, -64-bit macOS, and 32/64-bit x86/ARM Ubuntu Linux. See the releases tab on the -right (or click [here](https://github.com/halide/Halide/releases)). +The latest version of Halide can always be found on GitHub +at https://github.com/halide/Halide/releases + +We provide binary releases for many popular platforms and architectures, +including 32/64-bit x86 Windows, 64-bit macOS, and 32/64-bit x86/ARM +Ubuntu Linux. ## Vcpkg @@ -80,7 +82,7 @@ These are the **tested** host toolchain and platform combinations for building and running the Halide compiler library. | Compiler | Version | OS | Architectures | -| ---------- | ------------ | ---------------------- | --------------- | +|------------|--------------|------------------------|-----------------| | GCC | 7.5 | Ubuntu Linux 20.04 LTS | x86, x64, ARM32 | | GCC | 7.5 | Ubuntu Linux 18.04 LTS | ARM32, ARM64 | | MSVC | 2019 (19.28) | Windows 10 (20H2) | x86, x64 | @@ -107,14 +109,14 @@ issue. ### TL;DR -Have llvm-12.0 (or greater) installed and run `make` in the root directory of +Have llvm-13.0 (or greater) installed and run `make` in the root directory of the repository (where this README is). ### Acquiring LLVM At any point in time, building Halide requires either the latest stable version of LLVM, the previous stable version of LLVM, and trunk. At the time of writing, -this means versions 13.0 and 12.0 are supported, but 11.0 is not. The commands +this means versions 14.0 and 13.0 are supported, but 12.0 is not. The commands `llvm-config` and `clang` must be somewhere in the path. If your OS does not have packages for LLVM, you can find binaries for it at @@ -128,7 +130,7 @@ If you want to build it yourself, first check it out from GitHub: % git clone --depth 1 --branch llvmorg-13.0.0 https://github.com/llvm/llvm-project.git ``` -(If you want to build LLVM 12.x, use branch `release/12.x`; for current trunk, +(If you want to build LLVM 13.x, use branch `release/13.x`; for current trunk, use `main`) Then build it like so: @@ -283,7 +285,7 @@ Subsets of the tests can be selected with `-L` and include `correctness`, #### Building LLVM (optional) Follow these steps if you want to build LLVM yourself. First, download LLVM's -sources (these instructions use the latest 12.0 release) +sources (these instructions use the latest 13.0 release) ``` D:\> git clone --depth 1 --branch llvmorg-13.0.0 https://github.com/llvm/llvm-project.git diff --git a/README_cmake.md b/README_cmake.md index 1eeea286fd2d..3e6170f37ab7 100644 --- a/README_cmake.md +++ b/README_cmake.md @@ -67,7 +67,7 @@ we strongly suggest reading through the [CMake documentation][cmake-docs] first. ## Installing CMake -Halide requires at least version 3.16, which was released in November 2019. +Halide requires at least version 3.22, which was released in November 2021. Fortunately, getting a recent version of CMake couldn't be easier, and there are multiple good options on any system to do so. Generally, one should always have the most recent version of CMake installed system-wide. CMake is committed to @@ -116,8 +116,8 @@ is also a viable option. There are a few good ways to install a modern CMake on Ubuntu: -1. If you're on Ubuntu Linux 20.04 (focal), then simply running - `sudo apt install cmake` will get you CMake 3.16. +1. If you're on Ubuntu Linux 22.04 (Jammy Jellyfish), then simply running + `sudo apt install cmake` will get you CMake 3.22. 2. If you are on an older Ubuntu release or would like to use the newest CMake, try installing via the snap store: `snap install cmake`. Be sure you do not already have `cmake` installed via APT. The snap package automatically stays @@ -332,24 +332,20 @@ standard types: `Debug`, `RelWithDebInfo`, `MinSizeRel`, or `Release`. ### CMake Presets -If you are using CMake 3.19+, we provide several [presets][cmake_presets] to +If you are using CMake 3.21+, we provide several [presets][cmake_presets] to make the above commands more convenient. The following CMake preset commands correspond to the longer ones above. ``` -> cmake --preset=msvc-release # Ninja generator, MSVC compiler, Release build -> cmake --preset=win64 # VS 2019 generator, 64-bit build -> cmake --preset=win32 # VS 2019 generator, 32-bit build -$ cmake --preset=gcc-release # Ninja generator, GCC compiler, Release build +> cmake --preset=win64 # VS 2019 generator, 64-bit build, vcpkg deps +> cmake --preset=win32 # VS 2019 generator, 32-bit build, vcpkg deps +> cmake --preset=release # Release mode, any single-config generator / compiler -$ cmake --list-presets # Get full list of presets. +$ cmake --list-presets # Get full list of presets. ``` -The Windows and MSVC presets assume that the environment variable `VCPKG_ROOT` -is set and points to the root of the vcpkg installation. - -Note that the GCC presets do not define `NDEBUG` in release configurations, -departing from the usual CMake behavior. +The Windows presets assume that the environment variable `VCPKG_ROOT` is set and +points to the root of the vcpkg installation. ## Installing @@ -395,7 +391,7 @@ Halide's own CI infrastructure, or as escape hatches for third-party packagers. |-----------------------------|--------------------------------------------------------------------|------------------------------------------------------------------------------------------| | `Halide_CLANG_TIDY_BUILD` | `OFF` | Used internally to generate fake compile jobs for runtime files when running clang-tidy. | | `Halide_CCACHE_BUILD` | `OFF` | Use ccache with Halide-recommended settings to accelerate rebuilds. | -| `Halide_CCACHE_PARAMS` | `CCACHE_CPP2=yes CCACHE_HASHDIR=yes CCACHE_SLOPPINESS=pch_defines` | Options to pass to `ccache` when using `Halide_CCACHE_BUILD`. | +| `Halide_CCACHE_PARAMS` | `CCACHE_CPP2=yes CCACHE_HASHDIR=yes CCACHE_SLOPPINESS=pch_defines` | Options to pass to `ccache` when using `Halide_CCACHE_BUILD`. | | `Halide_SOVERSION_OVERRIDE` | `${Halide_VERSION_MAJOR}` | Override the SOVERSION for libHalide. Expects a positive integer (i.e. not a version). | The following options are only available when building Halide directly, ie. not @@ -567,7 +563,7 @@ No matter how you intend to use Halide, you will need some basic CMake boilerplate. ```cmake -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(HalideExample) set(CMAKE_CXX_STANDARD 17) # or newer @@ -702,8 +698,7 @@ autoscheduler: ```cmake add_halide_library(my_second_generator FROM my_generators - AUTOSCHEDULER Halide::Adams2019 - PARAMS auto_schedule=true) + AUTOSCHEDULER Halide::Adams2019) ``` ### RunGenMain @@ -805,12 +800,12 @@ Halide defines the following targets that are available to users: The following targets are not guaranteed to be available: -| Imported target | Description | -|-------------------------|------------------------------------------------------------------------------------------------------------------------------------------| -| `Halide::Python` | this is a Python 3 module that can be referenced as `$` when setting up Python tests or the like from CMake. | -| `Halide::Adams19` | the Adams et.al. 2019 autoscheduler (no GPU support) | -| `Halide::Li18` | the Li et.al. 2018 gradient autoscheduler (limited GPU support) | -| `Halide::Mullapudi2016` | the Mullapudi et.al. 2016 autoscheduler (no GPU support) | +| Imported target | Description | +|-------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `Halide::Python` | this is a Python 3 package that can be referenced as `$/..` when setting up `PYTHONPATH` for Python tests or the like from CMake. | +| `Halide::Adams19` | the Adams et.al. 2019 autoscheduler (no GPU support) | +| `Halide::Li18` | the Li et.al. 2018 gradient autoscheduler (limited GPU support) | +| `Halide::Mullapudi2016` | the Mullapudi et.al. 2016 autoscheduler (no GPU support) | ### Functions @@ -883,9 +878,9 @@ being created. When `TARGETS` is empty and the `host` target would not cross-compile, then `host` will be used. Otherwise, `cmake` will be used and an author warning will be issued. -To set the default autoscheduler, set the `AUTOSCHEDULER` argument to a target +To use an autoscheduler, set the `AUTOSCHEDULER` argument to a target named like `Namespace::Scheduler`, for example `Halide::Adams19`. This will set -the `-s` flag on the generator command line to `Scheduler` and add the target to +the `autoscheduler` GeneratorParam on the generator command line to `Scheduler` and add the target to the list of plugins. Additional plugins can be loaded by setting the `PLUGINS` argument. If the argument to `AUTOSCHEDULER` does not contain `::` or it does not name a target, it will be passed to the `-s` flag verbatim. @@ -998,7 +993,7 @@ would call `add_halide_library` with no `TARGETS` option and set `FROM` equal to the name of the imported generator executable. Obviously, this is a significant increase in complexity over a typical CMake project. -This is very compatible with the `add_halide_generator` strategy above. +This is very compatible with the `add_halide_generator` strategy above. ### Use `ExternalProject` directly @@ -1035,7 +1030,7 @@ using [`install(FILES)`][install-files] and the # Contributing CMake code to Halide When contributing new CMake code to Halide, keep in mind that the minimum -version is 3.16. Therefore, it is possible (and indeed required) to use modern +version is 3.22. Therefore, it is possible (and indeed required) to use modern CMake best practices. Like any large and complex system with a dedication to preserving backwards diff --git a/apps/CMakeLists.txt b/apps/CMakeLists.txt index cc6a598d0b29..7fa1585a5ebe 100644 --- a/apps/CMakeLists.txt +++ b/apps/CMakeLists.txt @@ -2,56 +2,60 @@ # Test apps from the perspective of a consuming project. ## -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(Halide_apps) +enable_testing() + if (WIN32) option(ENABLE_APPS_HANNK "Build apps/hannk" OFF) else () option(ENABLE_APPS_HANNK "Build apps/hannk" ON) endif () -enable_testing() - -# add_subdirectory(HelloAndroid) # TODO(#5374): missing CMake build -# add_subdirectory(HelloAndroidCamera2) # TODO(#5374): missing CMake build -# add_subdirectory(HelloMatlab) # TODO(#5374): missing CMake build -# add_subdirectory(HelloPyTorch) # TODO(#5374): missing CMake build -# add_subdirectory(HelloWasm) # TODO(#5374): missing CMake build -# add_subdirectory(HelloiOS) # TODO(#5374): missing CMake build -# add_subdirectory(auto_viz) # TODO(#5374): missing CMake build -add_subdirectory(bgu) -add_subdirectory(bilateral_grid) -add_subdirectory(blur) -add_subdirectory(c_backend) -add_subdirectory(camera_pipe) -add_subdirectory(conv_layer) -add_subdirectory(cuda_mat_mul) -add_subdirectory(depthwise_separable_conv) -add_subdirectory(fft) -if (ENABLE_APPS_HANNK) - add_subdirectory(hannk) -endif () -add_subdirectory(harris) -# add_subdirectory(hexagon_benchmarks) # TODO(#5374): missing CMake build -# add_subdirectory(hexagon_dma) # TODO(#5374): missing CMake build -add_subdirectory(hist) -add_subdirectory(iir_blur) -add_subdirectory(interpolate) -add_subdirectory(lens_blur) -add_subdirectory(linear_algebra) -# add_subdirectory(linear_blur) # TODO(#5374): missing CMake build -add_subdirectory(local_laplacian) -add_subdirectory(max_filter) -add_subdirectory(nl_means) -# add_subdirectory(nn_ops) # TODO(#5374): missing CMake build -# add_subdirectory(onnx) # TODO(#5374): missing CMake build -# add_subdirectory(openglcompute) # TODO(#5374): missing CMake build -add_subdirectory(resize) -# add_subdirectory(resnet_50) # TODO(#5374): missing CMake build -# add_subdirectory(simd_op_check) # TODO(#5374): missing CMake build -add_subdirectory(stencil_chain) -add_subdirectory(unsharp) -add_subdirectory(wavelet) +function(add_app app_name) + string(TOUPPER "ENABLE_APPS_${app_name}" opt) + option(${opt} "Build apps/${app_name}" ON) + if (${opt}) + add_subdirectory(${app_name}) + endif () +endfunction() -add_subdirectory(random_pipeline) # resurrecting random_pipeline +# add_app(HelloAndroid) # TODO(#5374): missing CMake build +# add_app(HelloAndroidCamera2) # TODO(#5374): missing CMake build +# add_app(HelloPyTorch) # TODO(#5374): missing CMake build +# add_app(HelloWasm) # TODO(#5374): missing CMake build +# add_app(HelloiOS) # TODO(#5374): missing CMake build +# add_app(auto_viz) # TODO(#5374): missing CMake build +add_app(bgu) +add_app(bilateral_grid) +add_app(blur) +add_app(c_backend) +add_app(camera_pipe) +add_app(conv_layer) +add_app(cuda_mat_mul) +add_app(depthwise_separable_conv) +add_app(fft) +add_app(hannk) +add_app(harris) +# add_app(hexagon_benchmarks) # TODO(#5374): missing CMake build +# add_app(hexagon_dma) # TODO(#5374): missing CMake build +add_app(hist) +add_app(iir_blur) +add_app(interpolate) +add_app(lens_blur) +add_app(linear_algebra) +# add_app(linear_blur) # TODO(#5374): missing CMake build +add_app(local_laplacian) +add_app(max_filter) +add_app(nl_means) +# add_app(nn_ops) # TODO(#5374): missing CMake build +# add_app(onnx) # TODO(#5374): missing CMake build +# add_app(openglcompute) # TODO(#5374): missing CMake build +add_app(resize) +# add_app(resnet_50) # TODO(#5374): missing CMake build +# add_app(simd_op_check) # TODO(#5374): missing CMake build +add_app(stencil_chain) +add_app(unsharp) +add_app(wavelet) +add_app(random_pipeline) # resurrecting random_pipeline diff --git a/apps/CMakePresets.json b/apps/CMakePresets.json new file mode 100644 index 000000000000..788ffeb0b6e1 --- /dev/null +++ b/apps/CMakePresets.json @@ -0,0 +1,60 @@ +{ + "version": 3, + "cmakeMinimumRequired": { + "major": 3, + "minor": 22, + "patch": 0 + }, + "configurePresets": [ + { + "name": "default", + "hidden": true, + "binaryDir": "build/${presetName}", + "installDir": "install/${presetName}" + }, + { + "name": "ci", + "hidden": true, + "inherits": "default", + "toolchainFile": "${sourceDir}/../cmake/toolchain.${presetName}.cmake", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "RelWithDebInfo" + } + }, + { + "name": "linux-x64-asan", + "inherits": "ci", + "displayName": "ASAN (Linux x64)", + "description": "Build everything with ASAN enabled", + "cacheVariables": { + "LLVM_ROOT": "$penv{LLVM_ROOT}", + "ENABLE_APPS_BGU": "OFF" + } + } + ], + "buildPresets": [ + { + "name": "linux-x64-asan", + "configurePreset": "linux-x64-asan", + "displayName": "ASAN (Linux x64)", + "description": "Build everything with ASAN enabled", + "environment": { + "ASAN_OPTIONS": "detect_leaks=0:detect_container_overflow=0" + } + } + ], + "testPresets": [ + { + "name": "linux-x64-asan", + "configurePreset": "linux-x64-asan", + "displayName": "ASAN (Linux x64)", + "description": "Test everything with ASAN enabled", + "environment": { + "ASAN_OPTIONS": "detect_leaks=0:detect_container_overflow=0" + }, + "output": { + "outputOnFailure": true + } + } + ] +} diff --git a/apps/HelloMatlab/Makefile b/apps/HelloMatlab/Makefile deleted file mode 100644 index 5aca66819f36..000000000000 --- a/apps/HelloMatlab/Makefile +++ /dev/null @@ -1,9 +0,0 @@ -include ../support/Makefile.inc - -.PHONY: build test -build: - @echo Nothing to build - -test: - ./run_blur.sh - diff --git a/apps/HelloMatlab/iir_blur.cpp b/apps/HelloMatlab/iir_blur.cpp deleted file mode 100644 index 2246fd0639b9..000000000000 --- a/apps/HelloMatlab/iir_blur.cpp +++ /dev/null @@ -1,86 +0,0 @@ -// This file defines a generator for a first order IIR low pass filter -// for a 2D image. - -#include "Halide.h" - -using namespace Halide; -using namespace Halide::BoundaryConditions; - -Var x, y, c; - -// Defines a func to blur the columns of an input with a first order low -// pass IIR filter, followed by a transpose. -Func blur_cols_transpose(Func input, Expr height, Expr alpha) { - Func blur; - - // Pure definition: do nothing. - blur(x, y, c) = undef(); - // Update 0: set the top row of the result to the input. - blur(x, 0, c) = input(x, 0, c); - // Update 1: run the IIR filter down the columns. - RDom ry(1, height - 1); - blur(x, ry, c) = - (1 - alpha) * blur(x, ry - 1, c) + alpha * input(x, ry, c); - // Update 2: run the IIR blur up the columns. - Expr flip_ry = height - ry - 1; - blur(x, flip_ry, c) = - (1 - alpha) * blur(x, flip_ry + 1, c) + alpha * blur(x, flip_ry, c); - - // Transpose the blur. - Func transpose; - transpose(x, y, c) = blur(y, x, c); - - // Schedule: - // Split the transpose into tiles of rows. Parallelize over channels - // and strips (Halide supports nested parallelism). - Var xo, yo; - transpose.compute_root() - .tile(x, y, xo, yo, x, y, 8, 8) - .vectorize(x) - .parallel(yo) - .parallel(c); - - // Run the filter on each row of tiles (which corresponds to a strip of - // columns in the input). - blur.compute_at(transpose, yo); - - // Vectorize computations within the strips. - blur.update(0) - .vectorize(x); - blur.update(1) - .reorder(x, ry) - .vectorize(x); - blur.update(2) - .reorder(x, ry) - .vectorize(x); - - return transpose; -} - -class IirBlur : public Generator { -public: - // This is the input image: a 3D (color) image with 32 bit float - // pixels. - Input> input{"input"}; - // The filter coefficient, alpha is the weight of the input to the - // filter. - Input alpha{"alpha"}; - - Output> output{"output"}; - - void generate() { - Expr width = input.width(); - Expr height = input.height(); - - // First, blur the columns of the input. - Func blury_T = blur_cols_transpose(input, height, alpha); - - // Blur the columns again (the rows of the original). - Func blur = blur_cols_transpose(blury_T, width, alpha); - - // Scheduling is done inside blur_cols_transpose. - output(x, y, c) = blur(x, y, c); - } -}; - -HALIDE_REGISTER_GENERATOR(IirBlur, IirBlur) diff --git a/apps/HelloMatlab/run_blur.m b/apps/HelloMatlab/run_blur.m deleted file mode 100644 index d3993a2bce83..000000000000 --- a/apps/HelloMatlab/run_blur.m +++ /dev/null @@ -1,22 +0,0 @@ -% Add the path to mex_halide.m. -addpath(fullfile(getenv('HALIDE_DISTRIB_PATH'), 'tools')); - -% Build the mex library from the blur generator. -mex_halide('iir_blur.cpp', '-g', 'IirBlur'); - -% Load the input, create an output buffer of equal size. -input = cast(imread('../images/rgb.png'), 'single') / 255; -output = zeros(size(input), 'single'); - -% The blur filter coefficient. -alpha = 0.1; - -% Call the Halide pipeline. -for i = 1:10 - tic; - iir_blur(input, alpha, output); - toc; -end - -% Write the blurred image. -imwrite(cast(output * 255, 'uint8'), 'blurred.png'); diff --git a/apps/HelloMatlab/run_blur.sh b/apps/HelloMatlab/run_blur.sh deleted file mode 100755 index 2733b9332bc2..000000000000 --- a/apps/HelloMatlab/run_blur.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash - -# This script is run by the nightly tests to check that mex_halide works. - -command -v octave >/dev/null 2>&1 || { echo >&2 "Octave not found. Aborting."; exit 0; } - -if [[ $CXX == *"-m32"* ]]; then - echo "Not proceeding because Halide is compiled in 32-bit mode but octave is (likely) 64-bit" - exit 0 -fi - -rm -f blurred.png iir_blur.mex -octave run_blur.m - -if [ -f blurred.png ] -then - echo "Success!" - exit 0 -fi - -echo "Failed to produce blurred.png!" -exit 1 diff --git a/apps/HelloPyTorch/Makefile b/apps/HelloPyTorch/Makefile index 15dd231b99de..c05d2826a475 100644 --- a/apps/HelloPyTorch/Makefile +++ b/apps/HelloPyTorch/Makefile @@ -84,8 +84,7 @@ $(BIN)/%/add_float32.a: $(GENERATOR_BIN)/add.generator -f add_float32 \ -e static_library,c_header,pytorch_wrapper \ -o $(@D) \ - target=$* \ - auto_schedule=false + target=$* $(BIN)/%/add_halidegrad_float32.a: $(GENERATOR_BIN)/add.generator @mkdir -p $(@D) @@ -95,11 +94,10 @@ $(BIN)/%/add_halidegrad_float32.a: $(GENERATOR_BIN)/add.generator -f add_halidegrad_float32 \ -e static_library,c_header,pytorch_wrapper \ -p $(HALIDE_DISTRIB_PATH)/lib/libautoschedule_li2018.so \ - -s Li2018 \ -o $(@D) \ -d 1 \ target=$* \ - auto_schedule=true + autoscheduler=Li2018 $(BIN)/%/add_grad_float32.a: $(GENERATOR_BIN)/add.generator @mkdir -p $(@D) @@ -109,8 +107,7 @@ $(BIN)/%/add_grad_float32.a: $(GENERATOR_BIN)/add.generator -f add_grad_float32 \ -e static_library,c_header,pytorch_wrapper \ -o $(@D) \ - target=$* \ - auto_schedule=false + target=$* $(BIN)/%/add_float64.a: $(GENERATOR_BIN)/add.generator @mkdir -p $(@D) @@ -120,8 +117,7 @@ $(BIN)/%/add_float64.a: $(GENERATOR_BIN)/add.generator -f add_float64 \ -e static_library,c_header,pytorch_wrapper \ -o $(@D) \ - target=$* \ - auto_schedule=false + target=$* $(BIN)/%/add_halidegrad_float64.a: $(GENERATOR_BIN)/add.generator @mkdir -p $(@D) @@ -132,11 +128,10 @@ $(BIN)/%/add_halidegrad_float64.a: $(GENERATOR_BIN)/add.generator -e static_library,c_header,pytorch_wrapper \ -o $(@D) \ -p $(HALIDE_DISTRIB_PATH)/lib/libautoschedule_li2018.so \ - -s Li2018 \ target=$* \ -d 1 \ target=$* \ - auto_schedule=true + autoscheduler=Li2018 $(BIN)/%/add_grad_float64.a: $(GENERATOR_BIN)/add.generator @mkdir -p $(@D) @@ -146,8 +141,7 @@ $(BIN)/%/add_grad_float64.a: $(GENERATOR_BIN)/add.generator -f add_grad_float64 \ -e static_library,c_header,pytorch_wrapper \ -o $(@D) \ - target=$* \ - auto_schedule=false + target=$* # ----------------------------------------------------------------------------- @@ -160,8 +154,7 @@ $(BIN)/%/add_cuda_float32.a: $(GENERATOR_BIN)/add.generator -f add_cuda_float32 \ -e static_library,c_header,pytorch_wrapper \ -o $(@D) \ - target=$(CUDA_TARGET) \ - auto_schedule=false + target=$(CUDA_TARGET) $(BIN)/%/add_halidegrad_cuda_float32.a: $(GENERATOR_BIN)/add.generator @mkdir -p $(@D) @@ -172,10 +165,9 @@ $(BIN)/%/add_halidegrad_cuda_float32.a: $(GENERATOR_BIN)/add.generator -e static_library,c_header,pytorch_wrapper \ -o $(@D) \ -p $(HALIDE_DISTRIB_PATH)/lib/libautoschedule_li2018.so \ - -s Li2018 \ -d 1 \ target=$(CUDA_TARGET) \ - auto_schedule=true + autoscheduler=Li2018 $(BIN)/%/add_grad_cuda_float32.a: $(GENERATOR_BIN)/add.generator @mkdir -p $(@D) @@ -185,8 +177,7 @@ $(BIN)/%/add_grad_cuda_float32.a: $(GENERATOR_BIN)/add.generator -f add_grad_cuda_float32 \ -e static_library,c_header,pytorch_wrapper \ -o $(@D) \ - target=$(CUDA_TARGET) \ - auto_schedule=false + target=$(CUDA_TARGET) $(BIN)/%/add_cuda_float64.a: $(GENERATOR_BIN)/add.generator @mkdir -p $(@D) @@ -196,8 +187,7 @@ $(BIN)/%/add_cuda_float64.a: $(GENERATOR_BIN)/add.generator -f add_cuda_float64 \ -e static_library,c_header,pytorch_wrapper \ -o $(@D) \ - target=$(CUDA_TARGET) \ - auto_schedule=false + target=$(CUDA_TARGET) $(BIN)/%/add_halidegrad_cuda_float64.a: $(GENERATOR_BIN)/add.generator @mkdir -p $(@D) @@ -208,10 +198,9 @@ $(BIN)/%/add_halidegrad_cuda_float64.a: $(GENERATOR_BIN)/add.generator -e static_library,c_header,pytorch_wrapper \ -o $(@D) \ -p $(HALIDE_DISTRIB_PATH)/lib/libautoschedule_li2018.so \ - -s Li2018 \ -d 1 \ target=$(CUDA_TARGET) \ - auto_schedule=true + autoscheduler=Li2018 $(BIN)/%/add_grad_cuda_float64.a: $(GENERATOR_BIN)/add.generator @mkdir -p $(@D) @@ -221,8 +210,7 @@ $(BIN)/%/add_grad_cuda_float64.a: $(GENERATOR_BIN)/add.generator -f add_grad_cuda_float64 \ -e static_library,c_header,pytorch_wrapper \ -o $(@D) \ - target=$(CUDA_TARGET) \ - auto_schedule=false + target=$(CUDA_TARGET) # ----------------------------------------------------------------------------- diff --git a/apps/HelloPyTorch/src/add_generator.cpp b/apps/HelloPyTorch/src/add_generator.cpp index 8f2d8f4d6a81..ccfaa937d5e9 100644 --- a/apps/HelloPyTorch/src/add_generator.cpp +++ b/apps/HelloPyTorch/src/add_generator.cpp @@ -30,7 +30,7 @@ class AddGenerator : public Generator { output.set_estimates({{0, kEdge}, {0, kEdge}, {0, kEdge}, {0, kEdge}}); // Schedule - if (!auto_schedule) { + if (!using_autoscheduler()) { Var tx("tx"), xy("xy"), cn("cn"), allvars("allvars"); if (get_target().has_gpu_feature()) { output @@ -84,7 +84,7 @@ class AddGradGenerator : public Generator { d_input_b.set_estimates({{0, kEdge}, {0, kEdge}, {0, kEdge}, {0, kEdge}}); // Schedule - if (!auto_schedule) { + if (!using_autoscheduler()) { Var tx("tx"), xy("xy"), cn("cn"), allvars("allvars"); if (get_target().has_gpu_feature()) { diff --git a/apps/HelloWasm/CMakeLists.txt b/apps/HelloWasm/CMakeLists.txt index ad2470e1e9db..206639111e84 100644 --- a/apps/HelloWasm/CMakeLists.txt +++ b/apps/HelloWasm/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(HelloWasm) enable_testing() @@ -10,10 +10,7 @@ set(CMAKE_CXX_EXTENSIONS NO) find_package(Halide REQUIRED) set(halide_includes "$") -find_program(EMCC emcc HINTS "$ENV{EMSDK}/upstream/emscripten") -if (NOT EMCC) - message(FATAL_ERROR "Could not find emscripten/emcc!") -endif () +find_program(EMCC emcc REQUIRED HINTS "$ENV{EMSDK}/upstream/emscripten") configure_file(index.html index.html COPYONLY) diff --git a/apps/bgu/CMakeLists.txt b/apps/bgu/CMakeLists.txt index 60f79339ff6d..0a2f9eb82eec 100644 --- a/apps/bgu/CMakeLists.txt +++ b/apps/bgu/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(bgu) enable_testing() diff --git a/apps/bgu/Makefile b/apps/bgu/Makefile index 297ceaee90b0..8eb687ec064a 100644 --- a/apps/bgu/Makefile +++ b/apps/bgu/Makefile @@ -16,11 +16,11 @@ $(GENERATOR_BIN)/bgu.generator: bgu_generator.cpp $(GENERATOR_DEPS) $(BIN)/%/bgu.a: $(GENERATOR_BIN)/bgu.generator @mkdir -p $(@D) - $< -g bgu -f bgu -o $(BIN)/$* target=$*-no_runtime auto_schedule=false + $< -g bgu -f bgu -o $(BIN)/$* target=$*-no_runtime $(BIN)/%/bgu_auto_schedule.a: $(GENERATOR_BIN)/bgu.generator @mkdir -p $(@D) - $< -g bgu -f bgu_auto_schedule -o $(BIN)/$* target=$*-no_runtime auto_schedule=true + $< -g bgu -f bgu_auto_schedule -o $(BIN)/$* target=$*-no_runtime autoscheduler=Mullapudi2016 $(BIN)/%/runtime.a: $(GENERATOR_BIN)/bgu.generator @mkdir -p $(@D) diff --git a/apps/bgu/bgu_generator.cpp b/apps/bgu/bgu_generator.cpp index 054df3e52ba6..1b2cff5b1dc7 100644 --- a/apps/bgu/bgu_generator.cpp +++ b/apps/bgu/bgu_generator.cpp @@ -430,7 +430,7 @@ class BGU : public Generator { b(2, 2) += weighted_lambda * gain; // Now solve Ax = b - Matrix<3, 4> result = transpose(solve_symmetric(A, b, line, x, auto_schedule, get_target())); + Matrix<3, 4> result = transpose(solve_symmetric(A, b, line, x, using_autoscheduler(), get_target())); // Pack the resulting matrix into the output Func. line(x, y, z, c) = pack_channels(c, {result(0, 0), @@ -509,7 +509,7 @@ class BGU : public Generator { output = slice; // Schedule - if (!auto_schedule) { + if (!using_autoscheduler()) { if (!get_target().has_gpu_feature()) { // 7.09 ms on an Intel i9-9960X using 16 threads // diff --git a/apps/bilateral_grid/CMakeLists.txt b/apps/bilateral_grid/CMakeLists.txt index 51b166a64a77..2b32f0911755 100644 --- a/apps/bilateral_grid/CMakeLists.txt +++ b/apps/bilateral_grid/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(bilateral_grid) enable_testing() diff --git a/apps/bilateral_grid/Makefile b/apps/bilateral_grid/Makefile index 405d3e3c6782..11d79fbcd946 100644 --- a/apps/bilateral_grid/Makefile +++ b/apps/bilateral_grid/Makefile @@ -10,11 +10,11 @@ $(GENERATOR_BIN)/bilateral_grid.generator: bilateral_grid_generator.cpp $(GENERA $(BIN)/%/bilateral_grid.a: $(GENERATOR_BIN)/bilateral_grid.generator @mkdir -p $(@D) - $^ -g bilateral_grid -e $(GENERATOR_OUTPUTS) -o $(@D) -f bilateral_grid target=$* auto_schedule=false + $^ -g bilateral_grid -e $(GENERATOR_OUTPUTS) -o $(@D) -f bilateral_grid target=$* $(BIN)/%/bilateral_grid_auto_schedule.a: $(GENERATOR_BIN)/bilateral_grid.generator @mkdir -p $(@D) - $^ -g bilateral_grid -e $(GENERATOR_OUTPUTS) -o $(@D) -f bilateral_grid_auto_schedule target=$*-no_runtime auto_schedule=true + $^ -g bilateral_grid -e $(GENERATOR_OUTPUTS) -o $(@D) -f bilateral_grid_auto_schedule target=$*-no_runtime autoscheduler=Mullapudi2016 $(BIN)/%/filter: filter.cpp $(BIN)/%/bilateral_grid.a $(BIN)/%/bilateral_grid_auto_schedule.a @mkdir -p $(@D) diff --git a/apps/bilateral_grid/bilateral_grid_generator.cpp b/apps/bilateral_grid/bilateral_grid_generator.cpp index ede57459d5ab..b1e07fb15cdf 100644 --- a/apps/bilateral_grid/bilateral_grid_generator.cpp +++ b/apps/bilateral_grid/bilateral_grid_generator.cpp @@ -80,7 +80,7 @@ class BilateralGrid : public Halide::Generator { blury.set_estimate(z, 0, 12); bilateral_grid.set_estimates({{0, 1536}, {0, 2560}}); - if (auto_schedule) { + if (using_autoscheduler()) { // nothing } else if (get_target().has_gpu_feature()) { // 0.50ms on an RTX 2060 diff --git a/apps/blur/CMakeLists.txt b/apps/blur/CMakeLists.txt index ac46ec980995..f4c3f5324a84 100644 --- a/apps/blur/CMakeLists.txt +++ b/apps/blur/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(blur) enable_testing() diff --git a/apps/c_backend/CMakeLists.txt b/apps/c_backend/CMakeLists.txt index 92a18f679b2b..0d134532dda8 100644 --- a/apps/c_backend/CMakeLists.txt +++ b/apps/c_backend/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(c_backend) enable_testing() diff --git a/apps/camera_pipe/CMakeLists.txt b/apps/camera_pipe/CMakeLists.txt index 04a700bd847c..0d5a94e26614 100644 --- a/apps/camera_pipe/CMakeLists.txt +++ b/apps/camera_pipe/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(camera_pipe) enable_testing() diff --git a/apps/camera_pipe/Makefile b/apps/camera_pipe/Makefile index 38f984d2af3e..b86698cd36ed 100644 --- a/apps/camera_pipe/Makefile +++ b/apps/camera_pipe/Makefile @@ -12,11 +12,11 @@ $(GENERATOR_BIN)/camera_pipe.generator: camera_pipe_generator.cpp $(GENERATOR_DE $(BIN)/%/camera_pipe.a: $(GENERATOR_BIN)/camera_pipe.generator @mkdir -p $(@D) - $^ -g camera_pipe -e $(GENERATOR_OUTPUTS) -o $(@D) -f camera_pipe target=$* auto_schedule=false + $^ -g camera_pipe -e $(GENERATOR_OUTPUTS) -o $(@D) -f camera_pipe target=$* $(BIN)/%/camera_pipe_auto_schedule.a: $(GENERATOR_BIN)/camera_pipe.generator @mkdir -p $(@D) - $^ -g camera_pipe -e $(GENERATOR_OUTPUTS) -o $(@D) -f camera_pipe_auto_schedule target=$*-no_runtime auto_schedule=true + $^ -g camera_pipe -e $(GENERATOR_OUTPUTS) -o $(@D) -f camera_pipe_auto_schedule target=$*-no_runtime autoscheduler=Mullapudi2016 $(BIN)/%/process: process.cpp $(BIN)/%/camera_pipe.a $(BIN)/%/camera_pipe_auto_schedule.a @mkdir -p $(@D) diff --git a/apps/camera_pipe/camera_pipe_generator.cpp b/apps/camera_pipe/camera_pipe_generator.cpp index ec0323676cd4..06251f5691bb 100644 --- a/apps/camera_pipe/camera_pipe_generator.cpp +++ b/apps/camera_pipe/camera_pipe_generator.cpp @@ -154,7 +154,7 @@ class Demosaic : public Halide::Generator { void schedule() { Pipeline p(output); - if (auto_schedule) { + if (using_autoscheduler()) { // blank } else if (get_target().has_gpu_feature()) { Var xi, yi; @@ -270,7 +270,7 @@ Func CameraPipe::color_correct(Func input) { Expr val = (matrix_3200(x, y) * alpha + matrix_7000(x, y) * (1 - alpha)); matrix(x, y) = cast(val * 256.0f); // Q8.8 fixed point - if (!auto_schedule) { + if (!using_autoscheduler()) { matrix.compute_root(); if (get_target().has_gpu_feature()) { matrix.gpu_single_thread(); @@ -331,7 +331,7 @@ Func CameraPipe::apply_curve(Func input) { // makeLUT add guard band outside of (minRaw, maxRaw]: curve(x) = select(x <= minRaw, 0, select(x > maxRaw, 255, val)); - if (!auto_schedule) { + if (!using_autoscheduler()) { // It's a LUT, compute it once ahead of time. curve.compute_root(); if (get_target().has_gpu_feature()) { @@ -370,7 +370,7 @@ Func CameraPipe::sharpen(Func input) { // Convert the sharpening strength to 2.5 fixed point. This allows sharpening in the range [0, 4]. Func sharpen_strength_x32("sharpen_strength_x32"); sharpen_strength_x32() = u8_sat(sharpen_strength * 32); - if (!auto_schedule) { + if (!using_autoscheduler()) { sharpen_strength_x32.compute_root(); if (get_target().has_gpu_feature()) { sharpen_strength_x32.gpu_single_thread(); @@ -439,12 +439,12 @@ void CameraPipe::generate() { processed.set_estimates({{0, 2592}, {0, 1968}, {0, 3}}); // Schedule - if (auto_schedule) { + if (using_autoscheduler()) { // nothing } else if (get_target().has_gpu_feature()) { // We can generate slightly better code if we know the output is even-sized - if (!auto_schedule) { + if (!using_autoscheduler()) { // TODO: The autoscheduler really ought to be able to // accommodate bounds on the output Func. Expr out_width = processed.width(); diff --git a/apps/conv_layer/CMakeLists.txt b/apps/conv_layer/CMakeLists.txt index e3df023cac8f..94674097d290 100644 --- a/apps/conv_layer/CMakeLists.txt +++ b/apps/conv_layer/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(conv_layer) enable_testing() diff --git a/apps/conv_layer/Makefile b/apps/conv_layer/Makefile index 2ac64101691f..43db9f9ee70a 100644 --- a/apps/conv_layer/Makefile +++ b/apps/conv_layer/Makefile @@ -10,11 +10,11 @@ $(GENERATOR_BIN)/conv_layer.generator: conv_layer_generator.cpp $(GENERATOR_DEPS $(BIN)/%/conv_layer.a: $(GENERATOR_BIN)/conv_layer.generator @mkdir -p $(@D) - $^ -g conv_layer -e $(GENERATOR_OUTPUTS) -o $(@D) -f conv_layer target=$* auto_schedule=false + $^ -g conv_layer -e $(GENERATOR_OUTPUTS) -o $(@D) -f conv_layer target=$* $(BIN)/%/conv_layer_auto_schedule.a: $(GENERATOR_BIN)/conv_layer.generator @mkdir -p $(@D) - $^ -g conv_layer -e $(GENERATOR_OUTPUTS) -o $(@D) -f conv_layer_auto_schedule target=$*-no_runtime auto_schedule=true + $^ -g conv_layer -e $(GENERATOR_OUTPUTS) -o $(@D) -f conv_layer_auto_schedule target=$*-no_runtime autoscheduler=Mullapudi2016 $(BIN)/%/process: process.cpp $(BIN)/%/conv_layer.a $(BIN)/%/conv_layer_auto_schedule.a @mkdir -p $(@D) diff --git a/apps/conv_layer/conv_layer_generator.cpp b/apps/conv_layer/conv_layer_generator.cpp index 5b6ff1ee5e10..a27d367a076d 100644 --- a/apps/conv_layer/conv_layer_generator.cpp +++ b/apps/conv_layer/conv_layer_generator.cpp @@ -49,7 +49,7 @@ class ConvolutionLayer : public Halide::Generator { bias.dim(0).set_bounds(0, CO).set_stride(1); - if (auto_schedule) { + if (using_autoscheduler()) { input.dim(0).set_estimate(0, CI); input.dim(1).set_estimate(0, W + 2); input.dim(2).set_estimate(0, H + 2); diff --git a/apps/cuda_mat_mul/CMakeLists.txt b/apps/cuda_mat_mul/CMakeLists.txt index b8bc8ad3e6eb..352553ec048b 100644 --- a/apps/cuda_mat_mul/CMakeLists.txt +++ b/apps/cuda_mat_mul/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(cuda_mat_mul) # This just checks whether CUDA is available ahead of time to allow diff --git a/apps/depthwise_separable_conv/CMakeLists.txt b/apps/depthwise_separable_conv/CMakeLists.txt index 6040f85f5f6a..11e24f335d11 100644 --- a/apps/depthwise_separable_conv/CMakeLists.txt +++ b/apps/depthwise_separable_conv/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(depthwise_separable_conv) enable_testing() diff --git a/apps/depthwise_separable_conv/Makefile b/apps/depthwise_separable_conv/Makefile index def2146eb3f6..001e12444809 100644 --- a/apps/depthwise_separable_conv/Makefile +++ b/apps/depthwise_separable_conv/Makefile @@ -8,11 +8,11 @@ $(GENERATOR_BIN)/depthwise_separable_conv.generator: depthwise_separable_conv_ge $(BIN)/%/depthwise_separable_conv.a: $(GENERATOR_BIN)/depthwise_separable_conv.generator @mkdir -p $(@D) - $^ -g depthwise_separable_conv -e $(GENERATOR_OUTPUTS) -o $(@D) -f depthwise_separable_conv target=$* auto_schedule=false + $^ -g depthwise_separable_conv -e $(GENERATOR_OUTPUTS) -o $(@D) -f depthwise_separable_conv target=$* $(BIN)/%/depthwise_separable_conv_auto_schedule.a: $(GENERATOR_BIN)/depthwise_separable_conv.generator @mkdir -p $(@D) - $^ -g depthwise_separable_conv -e $(GENERATOR_OUTPUTS) -o $(@D) -f depthwise_separable_conv_auto_schedule target=$*-no_runtime auto_schedule=true + $^ -g depthwise_separable_conv -e $(GENERATOR_OUTPUTS) -o $(@D) -f depthwise_separable_conv_auto_schedule target=$*-no_runtime autoscheduler=Mullapudi2016 $(BIN)/%/process: process.cpp $(BIN)/%/depthwise_separable_conv.a $(BIN)/%/depthwise_separable_conv_auto_schedule.a @-mkdir -p $(BIN) diff --git a/apps/depthwise_separable_conv/depthwise_separable_conv_generator.cpp b/apps/depthwise_separable_conv/depthwise_separable_conv_generator.cpp index d560a8bea376..ba230ee03653 100644 --- a/apps/depthwise_separable_conv/depthwise_separable_conv_generator.cpp +++ b/apps/depthwise_separable_conv/depthwise_separable_conv_generator.cpp @@ -74,7 +74,7 @@ class DepthwiseSeparableConvolution : public Generator>") diff --git a/apps/fft/fft.cpp b/apps/fft/fft.cpp index 79382129c763..862b3f3e81e5 100644 --- a/apps/fft/fft.cpp +++ b/apps/fft/fft.cpp @@ -107,7 +107,7 @@ ComplexExpr mul(ComplexExpr a, float re_b, float im_b) { // Specializations for some small DFTs of the first dimension of a // Func f. ComplexFunc dft2(ComplexFunc f, const string &prefix) { - Type type = f.output_types()[0]; + Type type = f.types()[0]; ComplexFunc F(prefix + "X2"); F(f.args()) = undef_z(type); @@ -122,7 +122,7 @@ ComplexFunc dft2(ComplexFunc f, const string &prefix) { } ComplexFunc dft4(ComplexFunc f, int sign, const string &prefix) { - Type type = f.output_types()[0]; + Type type = f.types()[0]; ComplexFunc F(prefix + "X4"); F(f.args()) = undef_z(type); @@ -156,7 +156,7 @@ ComplexFunc dft6(ComplexFunc f, int sign, const string &prefix) { ComplexExpr W2_3(re_W1_3, -im_W1_3); ComplexExpr W4_3 = W1_3; - Type type = f.output_types()[0]; + Type type = f.types()[0]; ComplexFunc F(prefix + "X8"); F(f.args()) = undef_z(type); @@ -187,7 +187,7 @@ ComplexFunc dft6(ComplexFunc f, int sign, const string &prefix) { ComplexFunc dft8(ComplexFunc f, int sign, const string &prefix) { const float sqrt2_2 = 0.70710678f; - Type type = f.output_types()[0]; + Type type = f.types()[0]; ComplexFunc F(prefix + "X8"); F(f.args()) = undef_z(type); @@ -346,7 +346,7 @@ ComplexFunc fft_dim1(ComplexFunc x, // The vector width is the least common multiple of the previous vector // width and the natural vector size for this stage. - vector_width = lcm(vector_width, target.natural_vector_size(v.output_types()[0])); + vector_width = lcm(vector_width, target.natural_vector_size(v.types()[0])); // Compute the R point DFT of the subtransform. ComplexFunc V = dft1d_c2c(v, R, sign, prefix); @@ -355,7 +355,7 @@ ComplexFunc fft_dim1(ComplexFunc x, // pass. Since the pure stage is undef, we explicitly generate the // arg list (because we can't use placeholders in an undef // definition). - exchange(A({n0, n1}, args)) = undef_z(V.output_types()[0]); + exchange(A({n0, n1}, args)) = undef_z(V.types()[0]); RDom rs(0, R, 0, N / R); r_ = rs.x; @@ -444,7 +444,7 @@ std::pair tiled_transpose(FuncType f, int max_tile_size, } const int tile_size = - std::min(max_tile_size, target.natural_vector_size(f.output_types()[0])); + std::min(max_tile_size, target.natural_vector_size(f.types()[0])); vector args = f.args(); Var x(args[0]), y(args[1]); @@ -685,7 +685,7 @@ ComplexFunc fft2d_r2c(Func r, int N0 = product(R0); int N1 = product(R1); - const int natural_vector_size = target.natural_vector_size(r.output_types()[0]); + const int natural_vector_size = target.natural_vector_size(r.types()[0]); // If this FFT is small, the logic related to zipping and unzipping // the FFT may be expensive compared to just brute forcing with a complex @@ -705,7 +705,7 @@ ComplexFunc fft2d_r2c(Func r, result(A({n0, n1}, args)) = dft(A({n0, n1}, args)); result.bound(n0, 0, N0); result.bound(n1, 0, (N1 + 1) / 2 + 1); - result.vectorize(n0, std::min(N0, target.natural_vector_size(result.output_types()[0]))); + result.vectorize(n0, std::min(N0, target.natural_vector_size(result.types()[0]))); dft.compute_at(result, outer); return result; } @@ -731,7 +731,7 @@ ComplexFunc fft2d_r2c(Func r, ComplexFunc zipped(prefix + "zipped"); int zip_width = desc.vector_width; if (zip_width <= 0) { - zip_width = target.natural_vector_size(r.output_types()[0]); + zip_width = target.natural_vector_size(r.types()[0]); } // Ensure the zip width divides the zipped extent. zip_width = gcd(zip_width, N0 / 2); @@ -911,7 +911,7 @@ Func fft2d_c2r(ComplexFunc c, // If this FFT is small, the logic related to zipping and unzipping // the FFT may be expensive compared to just brute forcing with a complex // FFT. - const int natural_vector_size = target.natural_vector_size(c.output_types()[0]); + const int natural_vector_size = target.natural_vector_size(c.types()[0]); bool skip_zip = N0 < natural_vector_size * 2; @@ -967,7 +967,7 @@ Func fft2d_c2r(ComplexFunc c, // The vector width of the zipping performed below. int zip_width = desc.vector_width; if (zip_width <= 0) { - zip_width = gcd(target.natural_vector_size(dft0T.output_types()[0]), N1 / 2); + zip_width = gcd(target.natural_vector_size(dft0T.types()[0]), N1 / 2); } // transpose so we can take the DFT of the columns again. diff --git a/apps/hannk/CMakeLists.txt b/apps/hannk/CMakeLists.txt index 8a36ab583b16..741a828d02e7 100644 --- a/apps/hannk/CMakeLists.txt +++ b/apps/hannk/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(hannk) # We need to set this for some of the subprojects pulled in by TFLite (eg flatbuffers) diff --git a/apps/hannk/README.md b/apps/hannk/README.md index 3f5156a55360..dbcd354e8e46 100644 --- a/apps/hannk/README.md +++ b/apps/hannk/README.md @@ -6,7 +6,7 @@ There are several front ends for the interpreter: - Direct API This app is a work in progress. Currently, only quantized uint8 networks are supported. -All of the [TensorFlow hosted models](https://www.tensorflow.org/lite/guide/hosted_models) +All of the [TensorFlow hosted models](https://tfhub.dev/s?deployment-format=lite) are working and producing good performance. ### Benchmarks diff --git a/apps/hannk/cmake/superbuild/CMakeLists.txt b/apps/hannk/cmake/superbuild/CMakeLists.txt index d5eeaf0a11c0..1ae17705f3cd 100644 --- a/apps/hannk/cmake/superbuild/CMakeLists.txt +++ b/apps/hannk/cmake/superbuild/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16...3.21) +cmake_minimum_required(VERSION 3.22...3.23) project(hannk_superbuild LANGUAGES NONE) ## diff --git a/apps/harris/CMakeLists.txt b/apps/harris/CMakeLists.txt index 173d4bf71f80..135129be8752 100644 --- a/apps/harris/CMakeLists.txt +++ b/apps/harris/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(harris) enable_testing() diff --git a/apps/harris/Makefile b/apps/harris/Makefile index 713c11d0c2c7..d99a72591d38 100644 --- a/apps/harris/Makefile +++ b/apps/harris/Makefile @@ -10,11 +10,11 @@ $(GENERATOR_BIN)/harris.generator: harris_generator.cpp $(GENERATOR_DEPS) $(BIN)/%/harris.a: $(GENERATOR_BIN)/harris.generator @mkdir -p $(@D) - $< -g harris -f harris -o $(BIN)/$* target=$*-no_runtime auto_schedule=false + $< -g harris -f harris -o $(BIN)/$* target=$*-no_runtime $(BIN)/%/harris_auto_schedule.a: $(GENERATOR_BIN)/harris.generator @mkdir -p $(@D) - $< -g harris -f harris_auto_schedule -o $(BIN)/$* target=$*-no_runtime auto_schedule=true + $< -g harris -f harris_auto_schedule -o $(BIN)/$* target=$*-no_runtime autoscheduler=Mullapudi2016 $(BIN)/%/runtime.a: $(GENERATOR_BIN)/harris.generator @mkdir -p $(@D) diff --git a/apps/harris/harris_generator.cpp b/apps/harris/harris_generator.cpp index feb16d1d7170..69cf8c05c68c 100644 --- a/apps/harris/harris_generator.cpp +++ b/apps/harris/harris_generator.cpp @@ -72,7 +72,7 @@ class Harris : public Halide::Generator { } // Schedule - if (!auto_schedule) { + if (!using_autoscheduler()) { Var xi("xi"), yi("yi"); if (get_target().has_gpu_feature()) { // 0.253ms on a 2060 RTX diff --git a/apps/hist/CMakeLists.txt b/apps/hist/CMakeLists.txt index aa8a532d0558..7aebca4984ec 100644 --- a/apps/hist/CMakeLists.txt +++ b/apps/hist/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(hist) enable_testing() diff --git a/apps/hist/Makefile b/apps/hist/Makefile index 5f4faa1b835a..b0843bda1fb0 100644 --- a/apps/hist/Makefile +++ b/apps/hist/Makefile @@ -12,11 +12,11 @@ $(GENERATOR_BIN)/hist.generator: hist_generator.cpp $(GENERATOR_DEPS) $(BIN)/%/hist.a: $(GENERATOR_BIN)/hist.generator @mkdir -p $(@D) - $< -g hist -f hist -o $(BIN)/$* target=$*-no_runtime auto_schedule=false + $< -g hist -f hist -o $(BIN)/$* target=$*-no_runtime $(BIN)/%/hist_auto_schedule.a: $(GENERATOR_BIN)/hist.generator @mkdir -p $(@D) - $< -g hist -f hist_auto_schedule -o $(BIN)/$* target=$*-no_runtime auto_schedule=true + $< -g hist -f hist_auto_schedule -o $(BIN)/$* target=$*-no_runtime autoscheduler=Mullapudi2016 $(BIN)/%/runtime.a: $(GENERATOR_BIN)/hist.generator @mkdir -p $(@D) diff --git a/apps/hist/hist_generator.cpp b/apps/hist/hist_generator.cpp index e3d5de7f5737..32d86d3d0186 100644 --- a/apps/hist/hist_generator.cpp +++ b/apps/hist/hist_generator.cpp @@ -64,7 +64,7 @@ class Hist : public Halide::Generator { } // Schedule - if (!auto_schedule) { + if (!using_autoscheduler()) { cdf.bound(x, 0, 256); Var xi("xi"), yi("yi"); diff --git a/apps/iir_blur/CMakeLists.txt b/apps/iir_blur/CMakeLists.txt index 44ba79e5d41a..6e61e7ecd0ee 100644 --- a/apps/iir_blur/CMakeLists.txt +++ b/apps/iir_blur/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(iir_blur) enable_testing() diff --git a/apps/iir_blur/Makefile b/apps/iir_blur/Makefile index 8c9983c8fa14..49104b3e5fa3 100644 --- a/apps/iir_blur/Makefile +++ b/apps/iir_blur/Makefile @@ -10,11 +10,11 @@ $(GENERATOR_BIN)/iir_blur.generator: iir_blur_generator.cpp $(GENERATOR_DEPS) $(BIN)/%/iir_blur.a: $(GENERATOR_BIN)/iir_blur.generator @mkdir -p $(@D) - $< -g iir_blur -f iir_blur -o $(BIN)/$* target=$*-no_runtime auto_schedule=false + $< -g iir_blur -f iir_blur -o $(BIN)/$* target=$*-no_runtime $(BIN)/%/iir_blur_auto_schedule.a: $(GENERATOR_BIN)/iir_blur.generator @mkdir -p $(@D) - $< -g iir_blur -f iir_blur_auto_schedule -o $(BIN)/$* target=$*-no_runtime auto_schedule=true + $< -g iir_blur -f iir_blur_auto_schedule -o $(BIN)/$* target=$*-no_runtime autoscheduler=Mullapudi2016 $(BIN)/%/runtime.a: $(GENERATOR_BIN)/iir_blur.generator @mkdir -p $(@D) diff --git a/apps/iir_blur/iir_blur_generator.cpp b/apps/iir_blur/iir_blur_generator.cpp index 59ef065e79e6..1aeb3e0d1a5f 100644 --- a/apps/iir_blur/iir_blur_generator.cpp +++ b/apps/iir_blur/iir_blur_generator.cpp @@ -145,10 +145,10 @@ class IirBlur : public Generator { Expr height = input.height(); // First, blur the columns of the input. - Func blury_T = blur_cols_transpose(input, height, alpha, auto_schedule, get_target()); + Func blury_T = blur_cols_transpose(input, height, alpha, using_autoscheduler(), get_target()); // Blur the columns again (the rows of the original). - Func blur = blur_cols_transpose(blury_T, width, alpha, auto_schedule, get_target()); + Func blur = blur_cols_transpose(blury_T, width, alpha, using_autoscheduler(), get_target()); // Scheduling is done inside blur_cols_transpose. output = blur; diff --git a/apps/interpolate/CMakeLists.txt b/apps/interpolate/CMakeLists.txt index 49addcf16dc5..d723ac3b35da 100644 --- a/apps/interpolate/CMakeLists.txt +++ b/apps/interpolate/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(interpolate) enable_testing() diff --git a/apps/interpolate/Makefile b/apps/interpolate/Makefile index 8e55e16a1283..95c165b533ee 100644 --- a/apps/interpolate/Makefile +++ b/apps/interpolate/Makefile @@ -12,11 +12,11 @@ $(GENERATOR_BIN)/interpolate.generator: interpolate_generator.cpp $(GENERATOR_DE $(BIN)/%/interpolate.a: $(GENERATOR_BIN)/interpolate.generator @mkdir -p $(@D) - $< -g interpolate -e $(GENERATOR_OUTPUTS) -f interpolate -o $(BIN)/$* target=$*-no_runtime auto_schedule=false + $< -g interpolate -e $(GENERATOR_OUTPUTS) -f interpolate -o $(BIN)/$* target=$*-no_runtime $(BIN)/%/interpolate_auto_schedule.a: $(GENERATOR_BIN)/interpolate.generator @mkdir -p $(@D) - $< -g interpolate -e $(GENERATOR_OUTPUTS) -f interpolate_auto_schedule -o $(BIN)/$* target=$*-no_runtime auto_schedule=true + $< -g interpolate -e $(GENERATOR_OUTPUTS) -f interpolate_auto_schedule -o $(BIN)/$* target=$*-no_runtime autoscheduler=Mullapudi2016 $(BIN)/%/runtime.a: $(GENERATOR_BIN)/interpolate.generator @mkdir -p $(@D) diff --git a/apps/interpolate/interpolate_generator.cpp b/apps/interpolate/interpolate_generator.cpp index 58d6d65374eb..1e4026b9ef87 100644 --- a/apps/interpolate/interpolate_generator.cpp +++ b/apps/interpolate/interpolate_generator.cpp @@ -72,7 +72,7 @@ class Interpolate : public Halide::Generator { normalize(x, y, c) = interpolated[0](x, y, c) / interpolated[0](x, y, 3); // Schedule - if (auto_schedule) { + if (using_autoscheduler()) { output = normalize; } else { // 0.86ms on a 2060 RTX diff --git a/apps/lens_blur/CMakeLists.txt b/apps/lens_blur/CMakeLists.txt index 194cccd37daf..dcd9a70e4ac7 100644 --- a/apps/lens_blur/CMakeLists.txt +++ b/apps/lens_blur/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(lens_blur) enable_testing() diff --git a/apps/lens_blur/Makefile b/apps/lens_blur/Makefile index 8ede6b797ffe..c5c424c82edf 100644 --- a/apps/lens_blur/Makefile +++ b/apps/lens_blur/Makefile @@ -11,11 +11,11 @@ $(GENERATOR_BIN)/lens_blur.generator: lens_blur_generator.cpp $(GENERATOR_DEPS) $(BIN)/%/lens_blur.a: $(GENERATOR_BIN)/lens_blur.generator @mkdir -p $(@D) - $^ -g lens_blur -e $(GENERATOR_OUTPUTS) -o $(@D) -f lens_blur target=$* auto_schedule=false + $^ -g lens_blur -e $(GENERATOR_OUTPUTS) -o $(@D) -f lens_blur target=$* $(BIN)/%/lens_blur_auto_schedule.a: $(GENERATOR_BIN)/lens_blur.generator @mkdir -p $(@D) - $^ -g lens_blur -e $(GENERATOR_OUTPUTS) -o $(@D) -f lens_blur_auto_schedule target=$*-no_runtime auto_schedule=true + $^ -g lens_blur -e $(GENERATOR_OUTPUTS) -o $(@D) -f lens_blur_auto_schedule target=$*-no_runtime autoscheduler=Mullapudi2016 $(BIN)/%/process: process.cpp $(BIN)/%/lens_blur.a $(BIN)/%/lens_blur_auto_schedule.a @mkdir -p $(@D) diff --git a/apps/lens_blur/lens_blur_generator.cpp b/apps/lens_blur/lens_blur_generator.cpp index 52fad46cb82b..14aa92c876f2 100644 --- a/apps/lens_blur/lens_blur_generator.cpp +++ b/apps/lens_blur/lens_blur_generator.cpp @@ -166,7 +166,7 @@ class LensBlur : public Halide::Generator { final.set_estimates({{0, 192}, {0, 320}, {0, 3}}); /* THE SCHEDULE */ - if (auto_schedule) { + if (using_autoscheduler()) { // nothing } else if (get_target().has_gpu_feature()) { // Manual GPU schedule diff --git a/apps/linear_algebra/CMakeLists.txt b/apps/linear_algebra/CMakeLists.txt index a9def55c6667..adbe63b91df4 100644 --- a/apps/linear_algebra/CMakeLists.txt +++ b/apps/linear_algebra/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(linear_algebra) enable_testing() @@ -12,42 +12,47 @@ set(CMAKE_CXX_EXTENSIONS NO) find_package(Halide REQUIRED) # Find BLAS-es -set(DEFAULT_BLAS "") -set(BLAS_TARGETS "") -set(BLAS_VENDORS OpenBLAS ATLAS Apple Generic) - -# ATLAS is weird and has extra requirements -find_library(CBLAS_LIBRARY cblas) -set(ATLAS_EXTRA_LIBS ${CBLAS_LIBRARY}) +set(found_blases "") +set(known_vendors OpenBLAS ATLAS Apple Intel10_64_dyn Generic) message(STATUS "Checking for available CBLAS implementations") -foreach (BLA_VENDOR IN LISTS BLAS_VENDORS) +foreach (BLA_VENDOR IN LISTS known_vendors) find_package(BLAS QUIET) - if (NOT BLAS_FOUND - OR ("${BLA_VENDOR}" STREQUAL "ATLAS" AND NOT CBLAS_LIBRARY) - OR ("${BLA_VENDOR}" STREQUAL "Generic" AND BLAS_TARGETS)) - message(STATUS "${BLA_VENDOR}: Missing") - else () - list(APPEND BLAS_LIBRARIES ${${BLA_VENDOR}_EXTRA_LIBS}) - - message(STATUS "${BLA_VENDOR}: Found ${BLAS_LIBRARIES}") - add_library(BLAS_${BLA_VENDOR} INTERFACE) - add_library(${BLA_VENDOR}::${BLA_VENDOR} ALIAS BLAS_${BLA_VENDOR}) - target_link_libraries(BLAS_${BLA_VENDOR} INTERFACE ${BLAS_LIBRARIES}) - target_link_options(BLAS_${BLA_VENDOR} INTERFACE ${BLAS_LINKER_FLAGS}) - target_include_directories(BLAS_${BLA_VENDOR} SYSTEM INTERFACE include) # Use CBlas header in our own tree. + # Fail early if not found + if (NOT BLAS_FOUND) + message(STATUS "${BLA_VENDOR}: Missing") + continue() + endif () - if (NOT DEFAULT_BLAS) - set(DEFAULT_BLAS ${BLA_VENDOR}::${BLA_VENDOR}) + # ATLAS is weird and has extra requirements + if (BLA_VENDOR STREQUAL "ATLAS") + find_library(CBLAS_LIBRARY cblas) + if (NOT CBLAS_LIBRARY) + message(STATUS "${BLA_VENDOR}: Missing dependency on CBLAS (hint: set CBLAS_LIBRARY)") + continue() endif () + list(APPEND BLAS_LIBRARIES "${CBLAS_LIBRARY}") + endif () - list(APPEND BLAS_TARGETS ${BLA_VENDOR}) + # Don't use "Generic" BLAS if any good BLAS is available. + if (BLA_VENDOR STREQUAL "Generic" AND found_blases) + message(STATUS "${BLA_VENDOR}: Not considered") + continue() endif () + + message(STATUS "${BLA_VENDOR}: Found ${BLAS_LIBRARIES}") + + add_library(BLAS::${BLA_VENDOR} INTERFACE IMPORTED) + target_link_libraries(BLAS::${BLA_VENDOR} INTERFACE ${BLAS_LIBRARIES}) + target_link_options(BLAS::${BLA_VENDOR} INTERFACE ${BLAS_LINKER_FLAGS}) + target_include_directories(BLAS::${BLA_VENDOR} INTERFACE "${CMAKE_CURRENT_SOURCE_DIR}/include") # Use CBlas header in our own tree. + + list(APPEND found_blases ${BLA_VENDOR}) endforeach () -if (NOT BLAS_TARGETS) - message(FATAL_ERROR "Could not find any BLAS libraries! Searched among ${BLAS_VENDORS}") +if (NOT found_blases) + message(FATAL_ERROR "Could not find any BLAS libraries! Searched among ${known_vendors}") endif () # Load in the rest of the project. diff --git a/apps/linear_algebra/Makefile b/apps/linear_algebra/Makefile index 3d8b561aefe2..dff6baf1a830 100644 --- a/apps/linear_algebra/Makefile +++ b/apps/linear_algebra/Makefile @@ -77,6 +77,16 @@ BENCHMARKS = \ all: build make run_benchmarks +# This is a hack: disable this test when compiling 32-bit systems, as it's hard to find the right 32-bit versions +# of these libraries on 64-bit hosts. Can't rely on HL_TARGET because it might be 'host' even for cross-compiling. +# Look instead for `-m32` being passed to CXX, which is the cross-compiling flag we use. This is regrettable +# but expedient. (Note that CMake is able to find this correctly, and so we have test coverage there; this is +# simply not worth debugging as an edge case at the moment.) +ifneq (,$(findstring -m32,$(CXX))) +build: + @echo linear_algebra not support using Make on 32-bit systems: skipping linear_algebra tests... +test: build +else ifneq ("$(wildcard /usr/include/cblas.h /usr/include/*/cblas.h)","") build: $(BENCHMARKS) $(BIN)/test_halide_blas test: $(BIN)/test_halide_blas @@ -86,6 +96,7 @@ build: @echo /usr/include/cblas.h not found: skipping linear_algebra tests... test: build endif +endif clean: rm -rf $(BIN) diff --git a/apps/linear_algebra/benchmarks/CMakeLists.txt b/apps/linear_algebra/benchmarks/CMakeLists.txt index ab76d880c017..e8e0a470c47e 100644 --- a/apps/linear_algebra/benchmarks/CMakeLists.txt +++ b/apps/linear_algebra/benchmarks/CMakeLists.txt @@ -1,7 +1,7 @@ add_executable(halide_benchmarks halide_benchmarks.cpp) target_compile_definitions(halide_benchmarks PRIVATE ENABLE_FTZ_DAZ) target_link_libraries(halide_benchmarks PRIVATE halide_blas Halide::Tools) -set(BENCHMARK_TARGETS halide_benchmarks) +set(benchmark_targets halide_benchmarks) find_package(Eigen3 QUIET) set(Eigen3 Eigen3::Eigen) @@ -18,39 +18,41 @@ if (TARGET ${Eigen3}) add_executable(eigen_benchmarks eigen_benchmarks.cpp) target_compile_definitions(eigen_benchmarks PRIVATE EIGEN_DONT_PARALLELIZE ENABLE_FTZ_DAZ) target_link_libraries(eigen_benchmarks PRIVATE ${Eigen3} Halide::Tools) - list(APPEND BENCHMARK_TARGETS eigen_benchmarks) + list(APPEND benchmark_targets eigen_benchmarks) message(STATUS "Eigen3: Found") else () message(STATUS "Eigen3: Missing") endif () -foreach (BLAS_TARGET IN LISTS BLAS_TARGETS) - set(TARGET ${BLAS_TARGET}_benchmarks) - add_executable(${TARGET} cblas_benchmarks.cpp) - target_compile_definitions(${TARGET} PRIVATE "BLAS_NAME=\"${BLAS_TARGET}\"") - target_link_libraries(${TARGET} PRIVATE ${BLAS_TARGET}::${BLAS_TARGET} Halide::Tools) - list(APPEND BENCHMARK_TARGETS ${TARGET}) +foreach (blas IN LISTS FOUND_BLASES) + set(blas_benchmarks "${blas}_benchmarks") + add_executable(${blas_benchmarks} cblas_benchmarks.cpp) + target_compile_definitions(${blas_benchmarks} PRIVATE "BLAS_NAME=\"${blas}\"") + target_link_libraries(${blas_benchmarks} PRIVATE BLAS::${blas} Halide::Tools) + list(APPEND benchmark_targets ${blas_benchmarks}) endforeach () # Large powers of two are a pathological case for the cache, so avoid # them for the benchmarks. -set(BLAS_LEVELS L1 L2 L3) -list(APPEND BENCHMARK_SIZES 64 128 256 512 1280 2560) -list(APPEND L1_BENCHMARKS scopy dcopy sscal dscal saxpy daxpy sdot ddot sasum dasum) -list(APPEND L2_BENCHMARKS sgemv_notrans dgemv_notrans sgemv_trans dgemv_trans sger dger) -list(APPEND L3_BENCHMARKS sgemm_notrans dgemm_notrans sgemm_transA dgemm_transA sgemm_transB dgemm_transB sgemm_transAB dgemm_transAB) - -foreach (TARGET IN LISTS BENCHMARK_TARGETS) - string(REPLACE "_benchmarks" "" BLA_VENDOR "${TARGET}") - foreach (LEVEL IN LISTS BLAS_LEVELS) - foreach (FUNC IN LISTS ${LEVEL}_BENCHMARKS) - foreach (SIZE IN LISTS BENCHMARK_SIZES) - set(TEST_NAME ${BLA_VENDOR}_${FUNC}_${SIZE}) - add_test(NAME ${TEST_NAME} - COMMAND ${TARGET} ${FUNC} ${SIZE}) - set_tests_properties("${TEST_NAME}" PROPERTIES - LABELS "linear_algebra;${BLA_VENDOR};${LEVEL};slow_tests" - PASS_REGULAR_EXPRESSION "${FUNC}[ \t]+${SIZE}" +set(blas_levels L1 L2 L3) +list(APPEND benchmark_sizes 64 128 256 512 1280 2560) +list(APPEND L1_functions scopy dcopy sscal dscal saxpy daxpy sdot ddot sasum dasum) +list(APPEND L2_functions sgemv_notrans dgemv_notrans sgemv_trans dgemv_trans sger dger) +list(APPEND L3_functions sgemm_notrans dgemm_notrans sgemm_transA dgemm_transA sgemm_transB dgemm_transB sgemm_transAB dgemm_transAB) + +foreach (benchmark IN LISTS benchmark_targets) + string(REPLACE "_benchmarks" "" vendor "${benchmark}") + foreach (level IN LISTS blas_levels) + foreach (func IN LISTS ${level}_functions) + foreach (size IN LISTS benchmark_sizes) + set(test_name ${vendor}_${func}_${size}) + + add_test(NAME ${test_name} + COMMAND ${benchmark} ${func} ${size}) + + set_tests_properties("${test_name}" PROPERTIES + LABELS "linear_algebra;${vendor};${level};slow_tests" + PASS_REGULAR_EXPRESSION "${func}[ \t]+${size}" SKIP_REGULAR_EXPRESSION "\\[SKIP\\]") endforeach () endforeach () diff --git a/apps/linear_algebra/tests/CMakeLists.txt b/apps/linear_algebra/tests/CMakeLists.txt index ee2707f54db1..750a923c58e3 100644 --- a/apps/linear_algebra/tests/CMakeLists.txt +++ b/apps/linear_algebra/tests/CMakeLists.txt @@ -1,5 +1,7 @@ add_executable(test_halide_blas test_halide_blas.cpp) -target_link_libraries(test_halide_blas PRIVATE ${DEFAULT_BLAS} halide_blas) +target_link_libraries(test_halide_blas PRIVATE BLAS::BLAS halide_blas) +target_include_directories(test_halide_blas PRIVATE "${linear_algebra_SOURCE_DIR}/include") + add_test(NAME test_halide_blas COMMAND test_halide_blas) set_tests_properties(test_halide_blas PROPERTIES LABELS linear_algebra diff --git a/apps/linear_blur/linear_blur_generator.cpp b/apps/linear_blur/linear_blur_generator.cpp index 9b18e4b4bd3d..ec9db2e8097b 100644 --- a/apps/linear_blur/linear_blur_generator.cpp +++ b/apps/linear_blur/linear_blur_generator.cpp @@ -17,7 +17,7 @@ struct LinearBlur : public Halide::Generator { Func srgb = linear_to_srgb::generate(this, {blurred}); output(x, y, c) = srgb(x, y, c); - if (auto_schedule) { + if (using_autoscheduler()) { input.set_estimates({{0, 1536}, {0, 2560}, {0, 4}}); output.set_estimates({{0, 1536}, {0, 2560}, {0, 4}}); } else { diff --git a/apps/linear_blur/linear_to_srgb_generator.cpp b/apps/linear_blur/linear_to_srgb_generator.cpp index adf7b9426712..a45285e3b5a8 100644 --- a/apps/linear_blur/linear_to_srgb_generator.cpp +++ b/apps/linear_blur/linear_to_srgb_generator.cpp @@ -17,7 +17,7 @@ struct LinearTosRGB : public Halide::Generator { } void schedule() { - if (auto_schedule) { + if (using_autoscheduler()) { const int W = 1536, H = 2560, C = 4; // Wart: Input are defined with Vars we don't know. // Might be x,y but might be _0,_1. Use the args() to work around. diff --git a/apps/linear_blur/simple_blur_generator.cpp b/apps/linear_blur/simple_blur_generator.cpp index a53a3e26c426..78d23ae253cd 100644 --- a/apps/linear_blur/simple_blur_generator.cpp +++ b/apps/linear_blur/simple_blur_generator.cpp @@ -22,7 +22,7 @@ struct SimpleBlur : public Halide::Generator { } void schedule() { - if (auto_schedule) { + if (using_autoscheduler()) { const int W = 1536, H = 2560, C = 4; // Wart: Input are defined with Vars we don't know. // Might be x,y but might be _0,_1. Use the args() to work around. diff --git a/apps/linear_blur/srgb_to_linear_generator.cpp b/apps/linear_blur/srgb_to_linear_generator.cpp index b03907463c83..95cf203ada85 100644 --- a/apps/linear_blur/srgb_to_linear_generator.cpp +++ b/apps/linear_blur/srgb_to_linear_generator.cpp @@ -17,7 +17,7 @@ struct sRGBToLinear : public Halide::Generator { } void schedule() { - if (auto_schedule) { + if (using_autoscheduler()) { const int W = 1536, H = 2560, C = 4; // Wart: Input are defined with Vars we don't know. // Might be x,y but might be _0,_1. Use the args() to work around. diff --git a/apps/local_laplacian/CMakeLists.txt b/apps/local_laplacian/CMakeLists.txt index 6086e178b854..077382da3a77 100644 --- a/apps/local_laplacian/CMakeLists.txt +++ b/apps/local_laplacian/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(local_laplacian) enable_testing() diff --git a/apps/local_laplacian/Makefile b/apps/local_laplacian/Makefile index 21fa7bf74f6b..a9f57b4de81a 100644 --- a/apps/local_laplacian/Makefile +++ b/apps/local_laplacian/Makefile @@ -10,11 +10,11 @@ $(GENERATOR_BIN)/local_laplacian.generator: local_laplacian_generator.cpp $(GENE $(BIN)/%/local_laplacian.a: $(GENERATOR_BIN)/local_laplacian.generator @mkdir -p $(@D) - $^ -g local_laplacian -e $(GENERATOR_OUTPUTS) -o $(@D) -f local_laplacian target=$* auto_schedule=false + $^ -g local_laplacian -e $(GENERATOR_OUTPUTS) -o $(@D) -f local_laplacian target=$* $(BIN)/%/local_laplacian_auto_schedule.a: $(GENERATOR_BIN)/local_laplacian.generator @mkdir -p $(@D) - $^ -g local_laplacian -e $(GENERATOR_OUTPUTS) -o $(@D) -f local_laplacian_auto_schedule target=$*-no_runtime auto_schedule=true + $^ -g local_laplacian -e $(GENERATOR_OUTPUTS) -o $(@D) -f local_laplacian_auto_schedule target=$*-no_runtime autoscheduler=Mullapudi2016 $(BIN)/%/process: process.cpp $(BIN)/%/local_laplacian.a $(BIN)/%/local_laplacian_auto_schedule.a @mkdir -p $(@D) diff --git a/apps/local_laplacian/local_laplacian_generator.cpp b/apps/local_laplacian/local_laplacian_generator.cpp index b1c697a2a3b7..ee6e7dc09c57 100644 --- a/apps/local_laplacian/local_laplacian_generator.cpp +++ b/apps/local_laplacian/local_laplacian_generator.cpp @@ -98,7 +98,7 @@ class LocalLaplacian : public Halide::Generator { output.set_estimates({{0, 1536}, {0, 2560}, {0, 3}}); /* THE SCHEDULE */ - if (auto_schedule) { + if (using_autoscheduler()) { // Nothing. } else if (get_target().has_gpu_feature()) { // GPU schedule. diff --git a/apps/max_filter/CMakeLists.txt b/apps/max_filter/CMakeLists.txt index 6c0b94bab7a1..68b228438c5f 100644 --- a/apps/max_filter/CMakeLists.txt +++ b/apps/max_filter/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(max_filter) enable_testing() diff --git a/apps/max_filter/Makefile b/apps/max_filter/Makefile index bd755774b2f5..ec7fdc7e0739 100644 --- a/apps/max_filter/Makefile +++ b/apps/max_filter/Makefile @@ -12,11 +12,11 @@ $(GENERATOR_BIN)/max_filter.generator: max_filter_generator.cpp $(GENERATOR_DEPS $(BIN)/%/max_filter.a: $(GENERATOR_BIN)/max_filter.generator @mkdir -p $(@D) - $< -g max_filter -f max_filter -o $(BIN)/$* target=$*-no_runtime auto_schedule=false + $< -g max_filter -f max_filter -o $(BIN)/$* target=$*-no_runtime $(BIN)/%/max_filter_auto_schedule.a: $(GENERATOR_BIN)/max_filter.generator @mkdir -p $(@D) - $< -g max_filter -f max_filter_auto_schedule -o $(BIN)/$* target=$*-no_runtime auto_schedule=true + $< -g max_filter -f max_filter_auto_schedule -o $(BIN)/$* target=$*-no_runtime autoscheduler=Mullapudi2016 $(BIN)/%/runtime.a: $(GENERATOR_BIN)/max_filter.generator @mkdir -p $(@D) diff --git a/apps/max_filter/max_filter_generator.cpp b/apps/max_filter/max_filter_generator.cpp index 02856a5e4604..bfe0c9457e23 100644 --- a/apps/max_filter/max_filter_generator.cpp +++ b/apps/max_filter/max_filter_generator.cpp @@ -64,7 +64,7 @@ class Max : public Halide::Generator { } // Schedule - if (!auto_schedule) { + if (!using_autoscheduler()) { if (get_target().has_gpu_feature()) { // 11.8ms on a 2060 RTX diff --git a/apps/nl_means/CMakeLists.txt b/apps/nl_means/CMakeLists.txt index 651a7535cbc0..a92bedb14a3b 100644 --- a/apps/nl_means/CMakeLists.txt +++ b/apps/nl_means/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(nl_means) enable_testing() diff --git a/apps/nl_means/Makefile b/apps/nl_means/Makefile index 2c7fecdccc47..109cb5af13f7 100644 --- a/apps/nl_means/Makefile +++ b/apps/nl_means/Makefile @@ -10,11 +10,11 @@ $(GENERATOR_BIN)/nl_means.generator: nl_means_generator.cpp $(GENERATOR_DEPS) $(BIN)/%/nl_means.a: $(GENERATOR_BIN)/nl_means.generator @mkdir -p $(@D) - $^ -g nl_means -e $(GENERATOR_OUTPUTS) -o $(@D) -f nl_means target=$* auto_schedule=false + $^ -g nl_means -e $(GENERATOR_OUTPUTS) -o $(@D) -f nl_means target=$* $(BIN)/%/nl_means_auto_schedule.a: $(GENERATOR_BIN)/nl_means.generator @mkdir -p $(@D) - $^ -g nl_means -e $(GENERATOR_OUTPUTS) -o $(@D) -f nl_means_auto_schedule target=$*-no_runtime auto_schedule=true + $^ -g nl_means -e $(GENERATOR_OUTPUTS) -o $(@D) -f nl_means_auto_schedule target=$*-no_runtime autoscheduler=Mullapudi2016 $(BIN)/%/process: process.cpp $(BIN)/%/nl_means.a $(BIN)/%/nl_means_auto_schedule.a @mkdir -p $(@D) diff --git a/apps/nl_means/nl_means_generator.cpp b/apps/nl_means/nl_means_generator.cpp index ec51844119ed..5b3e136111ff 100644 --- a/apps/nl_means/nl_means_generator.cpp +++ b/apps/nl_means/nl_means_generator.cpp @@ -81,7 +81,7 @@ class NonLocalMeans : public Halide::Generator { // Provide estimates on the output pipeline non_local_means.set_estimates({{0, 1536}, {0, 2560}, {0, 3}}); - if (auto_schedule) { + if (using_autoscheduler()) { // nothing } else if (get_target().has_gpu_feature()) { // 22 ms on a 2060 RTX diff --git a/apps/onnx/model.cpp b/apps/onnx/model.cpp index b2d1738d3680..eb7327974612 100644 --- a/apps/onnx/model.cpp +++ b/apps/onnx/model.cpp @@ -344,8 +344,6 @@ std::vector run( } Halide::Realization real(outputs); Halide::Target tgt = Halide::get_host_target(); - // Don't allow LLVM to mess with the code. - tgt.set_feature(Halide::Target::DisableLLVMLoopOpt, true); // Don't create buffers larger than 2GB since we use 32bit signed indices to // index the data stored in them. tgt.set_feature(Halide::Target::LargeBuffers, false); @@ -461,8 +459,6 @@ double benchmark( Halide::Realization real(outputs); Halide::Target tgt = Halide::get_host_target(); - // Don't allow LLVM to mess with the code. - tgt.set_feature(Halide::Target::DisableLLVMLoopOpt, true); // Don't create buffers larger than 2GB since we use 32bit signed indices to // index the data stored in them. tgt.set_feature(Halide::Target::LargeBuffers, false); diff --git a/apps/random_pipeline/random_pipeline_generator.cpp b/apps/random_pipeline/random_pipeline_generator.cpp index d555dbcc9560..b3a82c97d452 100644 --- a/apps/random_pipeline/random_pipeline_generator.cpp +++ b/apps/random_pipeline/random_pipeline_generator.cpp @@ -1052,7 +1052,7 @@ class RandomPipeline : public Halide::Generator { std::cout << "Approx size: " << stages.back().w << ", " << stages.back().h << ", " << stages.back().c << "\n"; Stage next = random_stage(stages); stages.push_back(next); - if (!auto_schedule) { + if (!using_autoscheduler()) { stages.back().func.compute_root().reorder(x, c, y).vectorize(x, 8).parallel(y, 8); } } @@ -1064,11 +1064,11 @@ class RandomPipeline : public Halide::Generator { Stage casted = cast_stage(output.type(), tail); output = casted.func; - if (!auto_schedule) { + if (!using_autoscheduler()) { output.compute_root().reorder(x, c, y).vectorize(x, 8).parallel(y); } - if (auto_schedule) { + if (using_autoscheduler()) { input.dim(0).set_estimate(0, 2000) .dim(1).set_estimate(0, 2000) .dim(2).set_estimate(0, 3); diff --git a/apps/resize/CMakeLists.txt b/apps/resize/CMakeLists.txt index 68b2effd3613..1b5d14233c74 100644 --- a/apps/resize/CMakeLists.txt +++ b/apps/resize/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(resize) enable_testing() diff --git a/apps/resnet_50/Makefile b/apps/resnet_50/Makefile index 3d1dd30c9ce8..5303bd06e449 100644 --- a/apps/resnet_50/Makefile +++ b/apps/resnet_50/Makefile @@ -17,7 +17,7 @@ $(GENERATOR_BIN)/resnet50.generator: Resnet50Generator.cpp $(GENERATOR_DEPS) $(BIN)/%/resnet50.a: $(GENERATOR_BIN)/resnet50.generator @mkdir -p $(@D) - $^ -g resnet50 -o $(@D) -f resnet50 target=$* auto_schedule=false + $^ -g resnet50 -o $(@D) -f resnet50 target=$* $(BIN)/%/process: process.cpp $(BIN)/%/resnet50.a @mkdir -p $(@D) diff --git a/apps/stencil_chain/CMakeLists.txt b/apps/stencil_chain/CMakeLists.txt index d2f422604e88..ed85b1ba0c36 100644 --- a/apps/stencil_chain/CMakeLists.txt +++ b/apps/stencil_chain/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(stencil_chain) enable_testing() diff --git a/apps/stencil_chain/Makefile b/apps/stencil_chain/Makefile index 116922d03095..4c2706e66cd5 100644 --- a/apps/stencil_chain/Makefile +++ b/apps/stencil_chain/Makefile @@ -10,11 +10,11 @@ $(GENERATOR_BIN)/stencil_chain.generator: stencil_chain_generator.cpp $(GENERATO $(BIN)/%/stencil_chain.a: $(GENERATOR_BIN)/stencil_chain.generator @mkdir -p $(@D) - $^ -g stencil_chain -e $(GENERATOR_OUTPUTS) -o $(@D) -f stencil_chain target=$* auto_schedule=false + $^ -g stencil_chain -e $(GENERATOR_OUTPUTS) -o $(@D) -f stencil_chain target=$* $(BIN)/%/stencil_chain_auto_schedule.a: $(GENERATOR_BIN)/stencil_chain.generator @mkdir -p $(@D) - $^ -g stencil_chain -e $(GENERATOR_OUTPUTS) -o $(@D) -f stencil_chain_auto_schedule target=$*-no_runtime auto_schedule=true + $^ -g stencil_chain -e $(GENERATOR_OUTPUTS) -o $(@D) -f stencil_chain_auto_schedule target=$*-no_runtime autoscheduler=Mullapudi2016 $(BIN)/%/process: process.cpp $(BIN)/%/stencil_chain.a $(BIN)/%/stencil_chain_auto_schedule.a @mkdir -p $(@D) diff --git a/apps/stencil_chain/stencil_chain_generator.cpp b/apps/stencil_chain/stencil_chain_generator.cpp index ebe07d51bdba..f62f269d6146 100644 --- a/apps/stencil_chain/stencil_chain_generator.cpp +++ b/apps/stencil_chain/stencil_chain_generator.cpp @@ -45,7 +45,7 @@ class StencilChain : public Halide::Generator { output.set_estimates({{0, width}, {0, height}}); } - if (auto_schedule) { + if (using_autoscheduler()) { // nothing } else if (get_target().has_gpu_feature()) { // GPU schedule diff --git a/apps/support/Makefile.inc b/apps/support/Makefile.inc index eb14ad405c05..dfb225e3bf9e 100644 --- a/apps/support/Makefile.inc +++ b/apps/support/Makefile.inc @@ -13,11 +13,6 @@ IMAGES ?= ../images UNAME ?= $(shell uname) SHELL = bash PYTHON ?= python3 - -# TODO(srj): the python bindings need to be put into the distrib folders; -# this is a hopefully-temporary workaround (https://github.com/halide/Halide/issues/4368) -HALIDE_PYTHON_BINDINGS_PATH ?= $(realpath ../../bin/python3_bindings) - BIN_DIR ?= bin # Most build outputs go into $(BIN)/$(HL_TARGET)/$(HL_TARGET)/, so that you can vary the test @@ -103,6 +98,11 @@ else SHARED_EXT=so endif +# We want to build Halide plugins as .so on all posixy systems, including OSX. +# This is called out as a named var to make it clear that the use +# is deliberate, not an accident. +PLUGIN_EXT=so + # We expect $ANDROID_NDK_ROOT to be defined by an env var. # We require at least NDK r19b or later. ANDROID_NDK_ROOT ?= /path/to/android_ndk_root @@ -180,9 +180,9 @@ LIBHALIDE_LDFLAGS_STATIC ?= $(LIB_HALIDE_STATIC) $(LDFLAGS) # Autoschedulers. Mullapudi2016 is currently the default, because it's fast. AUTOSCHEDULER ?= mcts ifneq ($(AUTOSCHEDULER),) -LIB_AUTOSCHEDULER ?= $(HALIDE_DISTRIB_PATH)/lib/libautoschedule_$(AUTOSCHEDULER).$(SHARED_EXT) +LIB_AUTOSCHEDULER ?= $(HALIDE_DISTRIB_PATH)/lib/libautoschedule_$(AUTOSCHEDULER).$(PLUGIN_EXT) ifeq ($(UNAME), Darwin) -LIBHALIDE_LDFLAGS += -Wl,-force_load $(HALIDE_DISTRIB_PATH)/lib/libautoschedule_$(AUTOSCHEDULER).$(SHARED_EXT) +LIBHALIDE_LDFLAGS += -Wl,-force_load $(HALIDE_DISTRIB_PATH)/lib/libautoschedule_$(AUTOSCHEDULER).$(PLUGIN_EXT) else LIBHALIDE_LDFLAGS += -Wl,--no-as-needed -lautoschedule_$(AUTOSCHEDULER) -Wl,--as-needed endif diff --git a/apps/support/autoscheduler.inc b/apps/support/autoscheduler.inc deleted file mode 100644 index fc3aeb8f1876..000000000000 --- a/apps/support/autoscheduler.inc +++ /dev/null @@ -1,99 +0,0 @@ -ifndef BIN -$(error BIN must be set prior to including autoscheduler.inc) -endif - -AUTOSCHED_SRC ?= $(realpath ../autoscheduler) - -# Default to $(BIN) so that the toplevel Makefile can put all build products -# into the build products directory (rather than into the source tree) -AUTOSCHED_BIN ?= $(BIN) -AUTOSCHED_SAMPLES_OUT ?= $(AUTOSCHED_SRC)/samples - -AUTOSCHED_WEIGHT_OBJECTS=$(AUTOSCHED_BIN)/baseline_weights.o - -# TODO(srj): depending on something not in the distrib folder isn't strictly -# kosher, but this is still experimental -$(AUTOSCHED_BIN)/binary2cpp: ../../tools/binary2cpp.cpp - @mkdir -p $(@D) - $(CXX) $< -o $@ - -$(AUTOSCHED_BIN)/baseline_weights.cpp: $(AUTOSCHED_BIN)/binary2cpp $(AUTOSCHED_SRC)/baseline.weights - @mkdir -p $(@D) - $(AUTOSCHED_BIN)/binary2cpp baseline_weights < $(AUTOSCHED_SRC)/baseline.weights > $@ - -$(AUTOSCHED_BIN)/baseline_weights.o: $(AUTOSCHED_BIN)/baseline_weights.cpp - $(CXX) -c $< -o $@ - -AUTOSCHED_COST_MODEL_LIBS=\ -$(AUTOSCHED_BIN)/cost_model/cost_model.a \ -$(AUTOSCHED_BIN)/cost_model/train_cost_model.a \ - -$(AUTOSCHED_BIN)/cost_model.generator: $(AUTOSCHED_SRC)/cost_model_generator.cpp \ - $(AUTOSCHED_SRC)/cost_model_schedule.h \ - $(AUTOSCHED_SRC)/NetworkSize.h \ - $(GENERATOR_DEPS) - @mkdir -p $(@D) - $(CXX) $(CXXFLAGS) $(filter %.cpp,$^) -o $@ $(LIBHALIDE_LDFLAGS) $(USE_EXPORT_DYNAMIC) - -$(AUTOSCHED_BIN)/auto_schedule_runtime.a: $(AUTOSCHED_BIN)/cost_model.generator - @mkdir -p $(@D) - $^ -r auto_schedule_runtime -o $(AUTOSCHED_BIN) target=$(HL_TARGET) - -$(AUTOSCHED_BIN)/cost_model/%.a: $(AUTOSCHED_BIN)/cost_model.generator - @mkdir -p $(@D) - $^ -g $* -o $(AUTOSCHED_BIN)/cost_model -f $* target=$(HL_TARGET)-no_runtime auto_schedule=false -e stmt,static_library,h,assembly - -# It's important to use dynamic lookups for undefined symbols here: all of libHalide -# is expected to be present (in the loading binary), so we explicitly make the symbols -# undefined rather than dependent on libHalide.so. -$(AUTOSCHED_BIN)/libauto_schedule.so: $(AUTOSCHED_SRC)/AutoSchedule.cpp \ - $(AUTOSCHED_SRC)/ASLog.cpp \ - $(AUTOSCHED_SRC)/DefaultCostModel.h \ - $(AUTOSCHED_SRC)/DefaultCostModel.cpp \ - $(AUTOSCHED_SRC)/Weights.h \ - $(AUTOSCHED_SRC)/Weights.cpp \ - $(AUTOSCHED_SRC)/FunctionDAG.h \ - $(AUTOSCHED_SRC)/FunctionDAG.cpp \ - $(AUTOSCHED_SRC)/LoopNest.h \ - $(AUTOSCHED_SRC)/LoopNest.cpp \ - $(AUTOSCHED_SRC)/Featurization.h \ - $(AUTOSCHED_SRC)/CostModel.h \ - $(AUTOSCHED_SRC)/PerfectHashMap.h \ - $(AUTOSCHED_WEIGHT_OBJECTS) \ - $(AUTOSCHED_COST_MODEL_LIBS) \ - $(GENERATOR_DEPS) \ - $(AUTOSCHED_BIN)/auto_schedule_runtime.a - @mkdir -p $(@D) - $(CXX) -shared $(USE_EXPORT_DYNAMIC) -fPIC -fvisibility=hidden -fvisibility-inlines-hidden $(CXXFLAGS) $(OPTIMIZE) -I $(AUTOSCHED_BIN)/cost_model $(filter-out %.h $(LIBHALIDE_LDFLAGS),$^) -o $@ $(HALIDE_SYSTEM_LIBS) - -$(AUTOSCHED_BIN)/retrain_cost_model: $(AUTOSCHED_SRC)/retrain_cost_model.cpp \ - $(AUTOSCHED_SRC)/ASLog.cpp \ - $(AUTOSCHED_SRC)/DefaultCostModel.h \ - $(AUTOSCHED_SRC)/DefaultCostModel.cpp \ - $(AUTOSCHED_SRC)/Weights.h \ - $(AUTOSCHED_SRC)/Weights.cpp \ - $(AUTOSCHED_SRC)/CostModel.h \ - $(AUTOSCHED_SRC)/NetworkSize.h \ - $(AUTOSCHED_COST_MODEL_LIBS) \ - $(AUTOSCHED_WEIGHT_OBJECTS) \ - $(AUTOSCHED_BIN)/auto_schedule_runtime.a - @mkdir -p $(@D) - $(CXX) $(CXXFLAGS) -frtti -Wall -I ../support -I $(AUTOSCHED_BIN)/cost_model $(OPTIMIZE) $(filter-out %.h,$^) -o $@ $(LIBHALIDE_LDFLAGS) $(USE_OPEN_MP) - -$(AUTOSCHED_BIN)/featurization_to_sample: $(AUTOSCHED_SRC)/featurization_to_sample.cpp - @mkdir -p $(@D) - $(CXX) $(CXXFLAGS) $< $(OPTIMIZE) -o $@ - -$(AUTOSCHED_BIN)/get_host_target: $(AUTOSCHED_SRC)/get_host_target.cpp $(LIB_HALIDE) $(HALIDE_DISTRIB_PATH)/include/Halide.h - @mkdir -p $(@D) - $(CXX) $(CXXFLAGS) $(filter %.cpp,$^) $(LIBHALIDE_LDFLAGS) $(OPTIMIZE) -o $@ - -$(AUTOSCHED_BIN)/weightsdir_to_weightsfile: $(AUTOSCHED_SRC)/weightsdir_to_weightsfile.cpp $(AUTOSCHED_SRC)/Weights.cpp - @mkdir -p $(@D) - $(CXX) $(CXXFLAGS) $^ $(OPTIMIZE) -o $@ - -# This is the value that machine_params defaults to if no custom value is specified; -# see MachineParams::generic() -HL_MACHINE_PARAMS ?= 32,25165824,160 - - diff --git a/apps/unsharp/CMakeLists.txt b/apps/unsharp/CMakeLists.txt index 7aed8269bbd0..443cf92fb0a3 100644 --- a/apps/unsharp/CMakeLists.txt +++ b/apps/unsharp/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(unsharp) enable_testing() diff --git a/apps/unsharp/Makefile b/apps/unsharp/Makefile index fa912ad172e1..047fc2854fb3 100644 --- a/apps/unsharp/Makefile +++ b/apps/unsharp/Makefile @@ -10,11 +10,11 @@ $(GENERATOR_BIN)/unsharp.generator: unsharp_generator.cpp $(GENERATOR_DEPS) $(BIN)/%/unsharp.a: $(GENERATOR_BIN)/unsharp.generator @mkdir -p $(@D) - $< -g unsharp -f unsharp -o $(BIN)/$* target=$*-no_runtime auto_schedule=false + $< -g unsharp -f unsharp -o $(BIN)/$* target=$*-no_runtime $(BIN)/%/unsharp_auto_schedule.a: $(GENERATOR_BIN)/unsharp.generator @mkdir -p $(@D) - $< -g unsharp -f unsharp_auto_schedule -o $(BIN)/$* target=$*-no_runtime auto_schedule=true + $< -g unsharp -f unsharp_auto_schedule -o $(BIN)/$* target=$*-no_runtime autoscheduler=Mullapudi2016 $(BIN)/%/runtime.a: $(GENERATOR_BIN)/unsharp.generator @mkdir -p $(@D) diff --git a/apps/unsharp/unsharp_generator.cpp b/apps/unsharp/unsharp_generator.cpp index d68702bf1e20..c1070b2753fe 100644 --- a/apps/unsharp/unsharp_generator.cpp +++ b/apps/unsharp/unsharp_generator.cpp @@ -61,7 +61,7 @@ class Unsharp : public Halide::Generator { } // Schedule - if (!auto_schedule) { + if (!using_autoscheduler()) { // Some Intel Mac Minis have GPUs that require tile sizes smaller than 32x32 // for this pipeline because they have too few registers. Drop to 16x16 to // avoid unexpected crashes in CI. diff --git a/apps/wavelet/CMakeLists.txt b/apps/wavelet/CMakeLists.txt index 3136e819fec7..bd142bcf07c4 100644 --- a/apps/wavelet/CMakeLists.txt +++ b/apps/wavelet/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(wavelet) enable_testing() diff --git a/cmake/AddCudaToTarget.cmake b/cmake/AddCudaToTarget.cmake deleted file mode 100644 index e475f1da4c2f..000000000000 --- a/cmake/AddCudaToTarget.cmake +++ /dev/null @@ -1,41 +0,0 @@ -function(add_cuda_to_target TARGET VISIBILITY) - if (TARGET CUDA::cuda_driver AND TARGET CUDA::cudart) - target_link_libraries(${TARGET} ${VISIBILITY} CUDA::cuda_driver CUDA::cudart) - return() - endif () - - find_package(CUDAToolkit QUIET) - if (TARGET CUDA::cuda_driver AND TARGET CUDA::cudart) - target_link_libraries(${TARGET} ${VISIBILITY} CUDA::cuda_driver CUDA::cudart) - return() - endif () - - # Find the package for the CUDA_TOOLKIT_ROOT_DIR hint. - find_package(CUDA QUIET) - if (NOT CUDA_FOUND) - set(CUDA_TOOLKIT_ROOT_DIR) - endif () - - # Find the CUDA driver library by doing what the CUDAToolkit module from - # CMake 3.17 does. - find_library(CUDA_DRIVER_LIBRARY - NAMES cuda_driver cuda - HINTS ${CUDA_TOOLKIT_ROOT_DIR} ENV CUDA_PATH - PATH_SUFFIXES nvidia/current lib64 lib/x64 lib) - if (NOT CUDA_DRIVER_LIBRARY) - # Don't try any stub directories until we have exhausted all other search locations. - find_library(CUDA_DRIVER_LIBRARY - NAMES cuda_driver cuda - HINTS ${CUDA_TOOLKIT_ROOT_DIR} ENV CUDA_PATH - PATH_SUFFIXES lib64/stubs lib/x64/stubs lib/stubs stubs) - endif () - mark_as_advanced(CUDA_DRIVER_LIBRARY) - - if (NOT CUDA_DRIVER_LIBRARY) - message(WARNING "CUDA driver library not found on system.") - return() - endif () - - target_include_directories(${TARGET} ${VISIBILITY} ${CUDA_INCLUDE_DIRS}) - target_link_libraries(${TARGET} ${VISIBILITY} ${CUDA_LIBRARIES} ${CUDA_DRIVER_LIBRARY}) -endfunction() diff --git a/cmake/BundleStatic.cmake b/cmake/BundleStatic.cmake index 0c1d562dcf79..023db8edd1f2 100644 --- a/cmake/BundleStatic.cmake +++ b/cmake/BundleStatic.cmake @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) ## # This module provides a utility for bundling a set of IMPORTED @@ -22,7 +22,7 @@ cmake_minimum_required(VERSION 3.16) ## # All of the IMPORTED_ and INTERFACE_ properties should be accounted for below. -# https://cmake.org/cmake/help/v3.16/manual/cmake-properties.7.html#properties-on-targets +# https://cmake.org/cmake/help/v3.22/manual/cmake-properties.7.html#properties-on-targets # Irrelevant properties: # IMPORTED_IMPLIB(_) # shared-only @@ -149,7 +149,7 @@ function(transfer_locations) get_property(lib TARGET ${ARG_FROM} PROPERTY "IMPORTED_LOCATION${cfg}") if (lib) - get_filename_component(stage "${lib}" NAME_WE) + cmake_path(GET lib STEM stage) set(stage "${CMAKE_CURRENT_BINARY_DIR}/${stage}.obj") if (NOT EXISTS "${stage}") @@ -184,7 +184,9 @@ function(transfer_locations) set(globs "") foreach (lang IN LISTS languages) - list(APPEND globs "${stage}/*${CMAKE_${lang}_OUTPUT_EXTENSION}") + if (DEFINED "CMAKE_${lang}_OUTPUT_EXTENSION") + list(APPEND globs "${stage}/*${CMAKE_${lang}_OUTPUT_EXTENSION}") + endif () endforeach () file(GLOB_RECURSE objects ${globs}) diff --git a/cmake/FindHalide.cmake b/cmake/FindHalide.cmake new file mode 100644 index 000000000000..cffa7dd90742 --- /dev/null +++ b/cmake/FindHalide.cmake @@ -0,0 +1,6 @@ +# This file should NOT be installed. +# It is used by python_bindings (and future externalizable projects) to satisfy +# calls to `find_package(Halide)` when used in-tree. + +message(VERBOSE "Spoofing find_package(Halide) since in-tree builds already have Halide available.") +set(Halide_FOUND 1) diff --git a/cmake/HalideGeneratorHelpers.cmake b/cmake/HalideGeneratorHelpers.cmake index 884b49fd8068..6f7163c4b2ab 100644 --- a/cmake/HalideGeneratorHelpers.cmake +++ b/cmake/HalideGeneratorHelpers.cmake @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) include(${CMAKE_CURRENT_LIST_DIR}/HalideTargetHelpers.cmake) @@ -80,10 +80,11 @@ function(_Halide_try_load_generators) # Communicate found information to the caller set(${ARG_PACKAGE_NAME}_FOUND "${${ARG_PACKAGE_NAME}_FOUND}" PARENT_SCOPE) - if (NOT ${ARG_PACKAGE_NAME}_FOUND AND CMAKE_CROSSCOMPILING) + if (NOT ${ARG_PACKAGE_NAME}_FOUND AND CMAKE_CROSSCOMPILING AND NOT CMAKE_CROSSCOMPILING_EMULATOR) message(WARNING - "${ARG_PACKAGE_NAME} were not found and it looks like you are cross-compiling. " - "This is likely to fail. Please set -D${ARG_PACKAGE_NAME}_ROOT=... at the CMake " + "'${ARG_PACKAGE_NAME}' was not found and it looks like you " + "are cross-compiling without an emulator. This is likely to " + "fail. Please set -D${ARG_PACKAGE_NAME}_ROOT=... at the CMake " "command line to the build directory of a host-built ${PROJECT_NAME}.") endif () endif () @@ -284,7 +285,6 @@ function(add_halide_library TARGET) # Attach an autoscheduler if the user requested it ## - set(autoscheduler "") if (ARG_AUTOSCHEDULER) if ("${ARG_AUTOSCHEDULER}" MATCHES "::") if (NOT TARGET "${ARG_AUTOSCHEDULER}") @@ -298,8 +298,7 @@ function(add_halide_library TARGET) elseif (NOT ARG_PLUGINS) message(AUTHOR_WARNING "AUTOSCHEDULER set to a scheduler name but no plugins were loaded") endif () - set(autoscheduler -s "${ARG_AUTOSCHEDULER}") - list(PREPEND ARG_PARAMS auto_schedule=true) + list(PREPEND ARG_PARAMS "autoscheduler=${ARG_AUTOSCHEDULER}") endif () ## @@ -324,7 +323,9 @@ function(add_halide_library TARGET) foreach (p IN LISTS ARG_PLUGINS) list(APPEND generator_plugins "$") endforeach () - set(generator_plugins -p "$>") + # $ gets confused about quoting. Just use list(JOIN) here instead. + list(JOIN generator_plugins $ generator_plugins_list) + set(generator_plugins -p ${generator_plugins_list}) endif () add_custom_command(OUTPUT ${generator_output_files} @@ -335,7 +336,6 @@ function(add_halide_library TARGET) -f "${ARG_FUNCTION_NAME}" -e "$>" ${generator_plugins} - ${autoscheduler} -o . "target=$>" ${ARG_PARAMS} @@ -492,19 +492,9 @@ function(_Halide_target_link_gpu_libs TARGET VISIBILITY) endif () if ("${ARGN}" MATCHES "metal") - find_library(METAL_LIBRARY Metal) - if (NOT METAL_LIBRARY) - message(AUTHOR_WARNING "Metal framework dependency not found on system.") - else () - target_link_libraries(${TARGET} ${VISIBILITY} "${METAL_LIBRARY}") - endif () - - find_library(FOUNDATION_LIBRARY Foundation) - if (NOT FOUNDATION_LIBRARY) - message(AUTHOR_WARNING "Foundation framework dependency not found on system.") - else () - target_link_libraries(${TARGET} ${VISIBILITY} "${FOUNDATION_LIBRARY}") - endif () + find_library(FOUNDATION_LIBRARY Foundation REQUIRED) + find_library(METAL_LIBRARY Metal REQUIRED) + target_link_libraries(${TARGET} ${VISIBILITY} "${FOUNDATION_LIBRARY}" "${METAL_LIBRARY}") endif () endfunction() diff --git a/cmake/HalideTargetHelpers.cmake b/cmake/HalideTargetHelpers.cmake index a7235d974a9b..9edb5cfd8fdd 100644 --- a/cmake/HalideTargetHelpers.cmake +++ b/cmake/HalideTargetHelpers.cmake @@ -1,3 +1,5 @@ +cmake_minimum_required(VERSION 3.22) + ## # Utilities for manipulating Halide target triples ## @@ -27,22 +29,16 @@ function(_Halide_cmake_target OUTVAR) set(${OUTVAR} "${arch}-${bits}-${os}" PARENT_SCOPE) endfunction() -function(_Halide_cache var val doc) - if (DEFINED ${var}) - set(${var} "${${var}}" CACHE STRING "${doc}") - else () - set(${var} "${val}" CACHE STRING "${doc}") - endif () -endfunction() - ## # Set Halide `host` and `cmake` meta-target values ## _Halide_cmake_target(_active_triple) -_Halide_cache(Halide_HOST_TARGET "${_active_triple}" "Halide target triple matching the Halide library") -_Halide_cache(Halide_CMAKE_TARGET "${_active_triple}" "Halide target triple matching the CMake target") +set(Halide_HOST_TARGET "${_active_triple}" + CACHE STRING "Halide target triple matching the Halide library") +set(Halide_CMAKE_TARGET "${_active_triple}" + CACHE STRING "Halide target triple matching the CMake target") unset(_active_triple) @@ -58,7 +54,8 @@ else () set(_default_target "${Halide_CMAKE_TARGET}") endif () -_Halide_cache(Halide_TARGET "${_default_target}" "The default target to use when AOT compiling") +set(Halide_TARGET "${_default_target}" + CACHE STRING "The default target to use when AOT compiling") unset(_default_target) diff --git a/cmake/HalideTestHelpers.cmake b/cmake/HalideTestHelpers.cmake index 92d182ca2e03..50b072f2814a 100644 --- a/cmake/HalideTestHelpers.cmake +++ b/cmake/HalideTestHelpers.cmake @@ -56,6 +56,10 @@ function(add_halide_test TARGET) SKIP_REGULAR_EXPRESSION "\\[SKIP\\]" WILL_FAIL ${args_EXPECT_FAILURE}) + set_target_properties(${TARGET} PROPERTIES + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN TRUE) + # Add a meta-target for each group, to allow us to build by group easily foreach (GROUP IN LISTS args_GROUPS) set(META_TARGET build_${GROUP}) @@ -77,7 +81,7 @@ function(tests) set(TEST_NAMES "") foreach (file IN LISTS args_SOURCES) - get_filename_component(name "${file}" NAME_WE) + cmake_path(GET file STEM name) set(TARGET "${PRIMARY_GROUP}_${name}") list(APPEND TEST_NAMES "${TARGET}") diff --git a/cmake/MakeShellPath.cmake b/cmake/MakeShellPath.cmake deleted file mode 100644 index 37596801c64a..000000000000 --- a/cmake/MakeShellPath.cmake +++ /dev/null @@ -1,15 +0,0 @@ -## -# Convenience function for creating shell paths -## - -function(make_shell_path OUTVAR) - if (WIN32) - set(SEP "\\$") - else () - set(SEP ":") - endif () - - list(TRANSFORM ARGN REPLACE "^(.+)$" "$") - string(REPLACE ";" "${SEP}" ARGN "${ARGN}") - set(${OUTVAR} "${ARGN}" PARENT_SCOPE) -endfunction() diff --git a/cmake/PythonExtensionHelpers.cmake b/cmake/PythonExtensionHelpers.cmake new file mode 100644 index 000000000000..1433f6557a73 --- /dev/null +++ b/cmake/PythonExtensionHelpers.cmake @@ -0,0 +1,123 @@ +include(HalideGeneratorHelpers) +include(TargetExportScript) + +set(_STUB_DIR "${Halide_SOURCE_DIR}/python_bindings/stub") + +# There are two sorts of Python Extensions that we can produce for a Halide Generator +# written in C++: +# +# - One that is essentially the 'native code' output of a Generator, wrapped with enough CPython +# glue code to make it callable from Python. This is analogous to the usual Generator output +# when building a C++ codebase, and is the usual mode used for distribution of final product; +# these correspond to 'ahead-of-time' (AOT) code generation. The resulting code has no dependency +# on libHalide. We'll refer to this sort of extension as an "AOT extension". +# +# - One that essentially *the Generator itself*, wrapped in CPython glue code to make it callable +# from Python at Halide compilation time. This is analogous to the (rarely used) GeneratorStub +# code that can be used to compose multiple Generators together. The resulting extension *does* +# depend on libHalide, and can be used in either JIT or AOT mode for compilation. +# We'll refer to this sort of extension as a "Stub extension". +# +# For testing purposes here, we don't bother using distutils/setuptools to produce a properly-packaged +# Python extension; rather, we simply produce a .so file with the correct name exported, and ensure +# it's in the PYTHONPATH when testing. +# +# In our build files here, we build both kinds of extension for every Generator in the generators/ +# directory (even though not all are used). As a simplistic way to distinguish between the two +# sorts of extensions, we use the unadorned Generator name for AOT extensions, and the Generator name +# suffixed with "_stub" for Stub extensions. (TODO: this is unsatisfyingly hackish; better suggestions +# would be welcome.) + +function(target_export_single_symbol TARGET SYMBOL) + configure_file("${_STUB_DIR}/ext.ldscript.apple.in" "${CMAKE_CURRENT_BINARY_DIR}/${TARGET}.ldscript.apple") + configure_file("${_STUB_DIR}/ext.ldscript.linux.in" "${CMAKE_CURRENT_BINARY_DIR}/${TARGET}.ldscript") + target_export_script( + ${TARGET} + APPLE_LD "${CMAKE_CURRENT_BINARY_DIR}/${TARGET}.ldscript.apple" + GNU_LD "${CMAKE_CURRENT_BINARY_DIR}/${TARGET}.ldscript" + ) +endfunction() + +function(add_python_aot_extension TARGET) + set(options) + set(oneValueArgs GENERATOR FUNCTION_NAME) + set(multiValueArgs SOURCES LINK_LIBRARIES FEATURES PARAMS) + cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + if (NOT ARG_GENERATOR) + set(ARG_GENERATOR "${TARGET}") + endif () + + if (NOT ARG_FUNCTION_NAME) + set(ARG_FUNCTION_NAME "${ARG_GENERATOR}") + endif () + + # Create the Halide generator executable. + add_executable(${TARGET}.generator ${ARG_SOURCES}) + target_link_libraries(${TARGET}.generator PRIVATE Halide::Generator ${ARG_LINK_LIBRARIES}) + + # TODO: this should work (and would be preferred to the code above) + # but CMake fails with "targets not yet defined"; investigate. + # add_halide_generator(${TARGET}.generator + # SOURCES ${ARG_SOURCES}) + + # Run the Generator to produce a static library of AOT code, + # plus the 'python_extension' code necessary to produce a useful + # AOT Extention for Python: + add_halide_library(aot_${TARGET} + FROM ${TARGET}.generator + GENERATOR ${ARG_GENERATOR} + FUNCTION_NAME ${ARG_FUNCTION_NAME} + PYTHON_EXTENSION ${TARGET}.py.cpp + FEATURES ${ARG_FEATURES} + PARAMS ${ARG_PARAMS} + TARGETS cmake) + + # Take the native-code output of the Generator, add the Python-Extension + # code (to make it callable from Python), and build it into the AOT Extension we need. + if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.17) + # Add soabi info (like cpython-310-x86_64-linux-gnu) + # when CMake is new enough to know how to do it. + set(abi_flags WITH_SOABI) + else () + set(abi_flags "") + endif () + + Python3_add_library(${TARGET} MODULE ${abi_flags} ${${TARGET}.py.cpp}) + target_link_libraries(${TARGET} PRIVATE aot_${TARGET}) + set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${ARG_FUNCTION_NAME}) + target_export_single_symbol(${TARGET} ${ARG_FUNCTION_NAME}) +endfunction() + +function(add_python_stub_extension TARGET) + set(options) + set(oneValueArgs GENERATOR MODULE) + set(multiValueArgs SOURCES LINK_LIBRARIES) + cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + if (NOT ARG_GENERATOR) + set(ARG_GENERATOR "${TARGET}") + endif () + + if (NOT ARG_MODULE) + set(ARG_MODULE "${TARGET}_stub") + endif () + + # Produce a Stub Extension for the same Generator: + # Compiling PyStub.cpp, then linking with the generator's .o file, PyStubImpl.o, + # plus the same libHalide being used by halide.so. + # + # Note that we set HALIDE_PYSTUB_MODULE_NAME to $*_stub (e.g. foo_stub) but + # set HALIDE_PYSTUB_GENERATOR_NAME to the unadorned name of the Generator. + Python3_add_library(${TARGET} MODULE ${_STUB_DIR}/PyStub.cpp ${ARG_SOURCES}) + set_target_properties(${TARGET} PROPERTIES + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN ON + POSITION_INDEPENDENT_CODE ON) + target_compile_definitions(${TARGET} PRIVATE + "HALIDE_PYSTUB_GENERATOR_NAME=${ARG_GENERATOR}" + "HALIDE_PYSTUB_MODULE_NAME=${ARG_MODULE}") + target_link_libraries(${TARGET} PRIVATE Halide::PyStubs ${ARG_LINK_LIBRARIES}) + set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${ARG_MODULE}) + target_export_single_symbol(${TARGET} ${ARG_MODULE}) +endfunction() diff --git a/cmake/TargetExportScript.cmake b/cmake/TargetExportScript.cmake index cbb980baa129..36e53558aff6 100644 --- a/cmake/TargetExportScript.cmake +++ b/cmake/TargetExportScript.cmake @@ -1,7 +1,4 @@ -# Note: in CMake 3.18+ there is a CheckLinkerFlags module that should be used to replace this. -# Sadly, CMake does not attempt to detect the underlying linker and people can try to use, eg. -# gold or lld via CMAKE_CXX_FLAGS. -include(CheckCXXSourceCompiles) +include(CheckLinkerFlag) function(target_export_script TARGET) set(options) @@ -10,7 +7,7 @@ function(target_export_script TARGET) cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) get_property(target_type TARGET ${TARGET} PROPERTY TYPE) - if (NOT target_type STREQUAL "SHARED_LIBRARY") + if (NOT target_type MATCHES "(SHARED|MODULE)_LIBRARY") # Linker scripts do nothing on non-shared libraries. return() endif () @@ -21,16 +18,8 @@ function(target_export_script TARGET) return() endif () - set(dummy_source [[ int main() { return 0; } ]]) - - # CMake doesn't recognize MSVC/ldd link.exe's unknown-option warnings - set(extra_errors FAIL_REGEX "LNK4044: unrecognized option|warning : ignoring unknown argument") - ## More linkers support the GNU syntax (ld, lld, gold), so try it first. - set(version_script "LINKER:--version-script=${ARG_GNU_LD}") - - set(CMAKE_REQUIRED_LINK_OPTIONS "${version_script}") - check_cxx_source_compiles("${dummy_source}" LINKER_HAS_FLAG_VERSION_SCRIPT ${extra_errors}) + check_linker_flag(CXX "LINKER:--version-script=${ARG_GNU_LD}" LINKER_HAS_FLAG_VERSION_SCRIPT) if (LINKER_HAS_FLAG_VERSION_SCRIPT) target_link_options(${TARGET} PRIVATE "${version_script}") @@ -39,10 +28,7 @@ function(target_export_script TARGET) endif () ## The Apple linker expects a different flag. - set(exported_symbols_list "LINKER:-exported_symbols_list,${ARG_APPLE_LD}") - - set(CMAKE_REQUIRED_LINK_OPTIONS "${exported_symbols_list}") - check_cxx_source_compiles("${dummy_source}" LINKER_HAS_FLAG_EXPORTED_SYMBOLS_LIST ${extra_errors}) + check_linker_flag(CXX "LINKER:-exported_symbols_list,${ARG_APPLE_LD}" LINKER_HAS_FLAG_EXPORTED_SYMBOLS_LIST) if (LINKER_HAS_FLAG_EXPORTED_SYMBOLS_LIST) target_link_options(${TARGET} PRIVATE "${exported_symbols_list}") diff --git a/cmake/toolchain.linux-x64-asan.cmake b/cmake/toolchain.linux-x64-asan.cmake new file mode 100644 index 000000000000..b582b5791715 --- /dev/null +++ b/cmake/toolchain.linux-x64-asan.cmake @@ -0,0 +1,32 @@ +# Toolchain for compiling with ASAN enabled on a Linux-x86-64 host. +# This is done as a "crosscompile" because we must use our LLVM version +# of clang for *all* compilation (rather than using it just for bitcode +# and letting the host compiler, eg gcc, handle everything else); ASAN +# essentially requires everything to be compiled with matching versions +# of the same compiler. +# +# Note: requires LLVM to be built with -DLLVM_ENABLE_RUNTIMES="compiler-rt;libcxx;libcxxabi;libunwind" +# +# Note: only tested with LLVM/Clang 16 as of this comment. Earlier versions +# may likely work but are not tested. + +set(CMAKE_SYSTEM_NAME Linux) +set(CMAKE_SYSTEM_PROCESSOR i686) + +set(CMAKE_C_COMPILER ${LLVM_ROOT}/bin/clang) +set(CMAKE_CXX_COMPILER ${LLVM_ROOT}/bin/clang++) +set(CMAKE_LINKER ${LLVM_ROOT}/bin/ld.lld) + +set(CMAKE_C_FLAGS_INIT "-fsanitize=address") +set(CMAKE_CXX_FLAGS_INIT "-fsanitize=address") + +set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=${CMAKE_LINKER}") +set(CMAKE_MODULE_LINKER_FLAGS_INIT "-fuse-ld=${CMAKE_LINKER}") +set(CMAKE_SHARED_LINKER_FLAGS_INIT "-fuse-ld=${CMAKE_LINKER}") + +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) + +set(CMAKE_CROSSCOMPILING_EMULATOR /usr/bin/env) diff --git a/dependencies/CMakeLists.txt b/dependencies/CMakeLists.txt index 002afd0bcd7d..ea4c5a173493 100644 --- a/dependencies/CMakeLists.txt +++ b/dependencies/CMakeLists.txt @@ -26,5 +26,9 @@ add_subdirectory(llvm) add_subdirectory(jpeg) add_subdirectory(png) +if (TARGET_SPIRV) + add_subdirectory(spirv) +endif() + # Needs cache vars set by llvm, do not reorder. add_subdirectory(wasm) diff --git a/dependencies/llvm/CMakeLists.txt b/dependencies/llvm/CMakeLists.txt index 248aaf9ea4e3..7c1f26d6b72f 100644 --- a/dependencies/llvm/CMakeLists.txt +++ b/dependencies/llvm/CMakeLists.txt @@ -20,12 +20,12 @@ message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") message(STATUS "Using ClangConfig.cmake in: ${Clang_DIR}") -if (LLVM_PACKAGE_VERSION VERSION_LESS 12.0) - message(FATAL_ERROR "LLVM version must be 12.0 or newer") +if (LLVM_PACKAGE_VERSION VERSION_LESS 13.0) + message(FATAL_ERROR "LLVM version must be 13.0 or newer") endif () -if (LLVM_PACKAGE_VERSION VERSION_GREATER 15.0) - message(WARNING "Halide is not tested on LLVM versions beyond 15.0") +if (LLVM_PACKAGE_VERSION VERSION_GREATER 16.0) + message(WARNING "Halide is not tested on LLVM versions beyond 16.0") endif () set(Halide_LLVM_DEFS ${LLVM_DEFINITIONS} $) diff --git a/dependencies/spirv/CMakeLists.txt b/dependencies/spirv/CMakeLists.txt new file mode 100644 index 000000000000..060b07e906df --- /dev/null +++ b/dependencies/spirv/CMakeLists.txt @@ -0,0 +1,3 @@ + +set(SPIRV_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/include" CACHE INTERNAL "Location of SPIR-V include directory") +message(STATUS "Using SPIR-V headers from: ${SPIRV_INCLUDE_DIR}...") diff --git a/dependencies/spirv/LICENSE.txt b/dependencies/spirv/LICENSE.txt new file mode 100644 index 000000000000..47974f8ce39c --- /dev/null +++ b/dependencies/spirv/LICENSE.txt @@ -0,0 +1,25 @@ +Copyright (c) 2015-2018 The Khronos Group Inc. + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and/or associated documentation files (the +"Materials"), to deal in the Materials without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Materials, and to +permit persons to whom the Materials are furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be included +in all copies or substantial portions of the Materials. + +MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS +KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS +SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT + https://www.khronos.org/registry/ + +THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. diff --git a/dependencies/spirv/README.md b/dependencies/spirv/README.md new file mode 100644 index 000000000000..3e8c2531473e --- /dev/null +++ b/dependencies/spirv/README.md @@ -0,0 +1,38 @@ +# SPIR-V Headers + +This folder contains a copy of the officially released v1.0 ANSI-C header +file for [SPIR-V](https://www.khronos.org/registry/spir-v/), obtained from +the [https://github.com/KhronosGroup/SPIRV-Headers](https://github.com/KhronosGroup/SPIRV-Headers). + +The directory structure within this folder matches that of the official +versioned include path. + +## License + +``` +Copyright (c) 2015-2018 The Khronos Group Inc. + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and/or associated documentation files (the +"Materials"), to deal in the Materials without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Materials, and to +permit persons to whom the Materials are furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be included +in all copies or substantial portions of the Materials. + +MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS +KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS +SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT + https://www.khronos.org/registry/ + +THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. +``` diff --git a/dependencies/spirv/include/spirv/1.0/spirv.h b/dependencies/spirv/include/spirv/1.0/spirv.h new file mode 100644 index 000000000000..bd5a9b9593aa --- /dev/null +++ b/dependencies/spirv/include/spirv/1.0/spirv.h @@ -0,0 +1,993 @@ +/* +** Copyright (c) 2014-2018 The Khronos Group Inc. +** +** Permission is hereby granted, free of charge, to any person obtaining a copy +** of this software and/or associated documentation files (the "Materials"), +** to deal in the Materials without restriction, including without limitation +** the rights to use, copy, modify, merge, publish, distribute, sublicense, +** and/or sell copies of the Materials, and to permit persons to whom the +** Materials are furnished to do so, subject to the following conditions: +** +** The above copyright notice and this permission notice shall be included in +** all copies or substantial portions of the Materials. +** +** MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS KHRONOS +** STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS SPECIFICATIONS AND +** HEADER INFORMATION ARE LOCATED AT https://www.khronos.org/registry/ +** +** THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +** OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +** FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +** THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +** LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +** FROM,OUT OF OR IN CONNECTION WITH THE MATERIALS OR THE USE OR OTHER DEALINGS +** IN THE MATERIALS. +*/ + +/* +** This header is automatically generated by the same tool that creates +** the Binary Section of the SPIR-V specification. +*/ + +/* +** Enumeration tokens for SPIR-V, in various styles: +** C, C++, C++11, JSON, Lua, Python +** +** - C will have tokens with a "Spv" prefix, e.g.: SpvSourceLanguageGLSL +** - C++ will have tokens in the "spv" name space, e.g.: spv::SourceLanguageGLSL +** - C++11 will use enum classes in the spv namespace, e.g.: spv::SourceLanguage::GLSL +** - Lua will use tables, e.g.: spv.SourceLanguage.GLSL +** - Python will use dictionaries, e.g.: spv['SourceLanguage']['GLSL'] +** +** Some tokens act like mask values, which can be OR'd together, +** while others are mutually exclusive. The mask-like ones have +** "Mask" in their name, and a parallel enum that has the shift +** amount (1 << x) for each corresponding enumerant. +*/ + +#ifndef spirv_H +#define spirv_H + +typedef unsigned int SpvId; + +#define SPV_VERSION 0x10000 +#define SPV_REVISION 12 + +static const unsigned int SpvMagicNumber = 0x07230203; +static const unsigned int SpvVersion = 0x00010000; +static const unsigned int SpvRevision = 12; +static const unsigned int SpvOpCodeMask = 0xffff; +static const unsigned int SpvWordCountShift = 16; + +typedef enum SpvSourceLanguage_ { + SpvSourceLanguageUnknown = 0, + SpvSourceLanguageESSL = 1, + SpvSourceLanguageGLSL = 2, + SpvSourceLanguageOpenCL_C = 3, + SpvSourceLanguageOpenCL_CPP = 4, + SpvSourceLanguageHLSL = 5, + SpvSourceLanguageMax = 0x7fffffff, +} SpvSourceLanguage; + +typedef enum SpvExecutionModel_ { + SpvExecutionModelVertex = 0, + SpvExecutionModelTessellationControl = 1, + SpvExecutionModelTessellationEvaluation = 2, + SpvExecutionModelGeometry = 3, + SpvExecutionModelFragment = 4, + SpvExecutionModelGLCompute = 5, + SpvExecutionModelKernel = 6, + SpvExecutionModelMax = 0x7fffffff, +} SpvExecutionModel; + +typedef enum SpvAddressingModel_ { + SpvAddressingModelLogical = 0, + SpvAddressingModelPhysical32 = 1, + SpvAddressingModelPhysical64 = 2, + SpvAddressingModelMax = 0x7fffffff, +} SpvAddressingModel; + +typedef enum SpvMemoryModel_ { + SpvMemoryModelSimple = 0, + SpvMemoryModelGLSL450 = 1, + SpvMemoryModelOpenCL = 2, + SpvMemoryModelMax = 0x7fffffff, +} SpvMemoryModel; + +typedef enum SpvExecutionMode_ { + SpvExecutionModeInvocations = 0, + SpvExecutionModeSpacingEqual = 1, + SpvExecutionModeSpacingFractionalEven = 2, + SpvExecutionModeSpacingFractionalOdd = 3, + SpvExecutionModeVertexOrderCw = 4, + SpvExecutionModeVertexOrderCcw = 5, + SpvExecutionModePixelCenterInteger = 6, + SpvExecutionModeOriginUpperLeft = 7, + SpvExecutionModeOriginLowerLeft = 8, + SpvExecutionModeEarlyFragmentTests = 9, + SpvExecutionModePointMode = 10, + SpvExecutionModeXfb = 11, + SpvExecutionModeDepthReplacing = 12, + SpvExecutionModeDepthGreater = 14, + SpvExecutionModeDepthLess = 15, + SpvExecutionModeDepthUnchanged = 16, + SpvExecutionModeLocalSize = 17, + SpvExecutionModeLocalSizeHint = 18, + SpvExecutionModeInputPoints = 19, + SpvExecutionModeInputLines = 20, + SpvExecutionModeInputLinesAdjacency = 21, + SpvExecutionModeTriangles = 22, + SpvExecutionModeInputTrianglesAdjacency = 23, + SpvExecutionModeQuads = 24, + SpvExecutionModeIsolines = 25, + SpvExecutionModeOutputVertices = 26, + SpvExecutionModeOutputPoints = 27, + SpvExecutionModeOutputLineStrip = 28, + SpvExecutionModeOutputTriangleStrip = 29, + SpvExecutionModeVecTypeHint = 30, + SpvExecutionModeContractionOff = 31, + SpvExecutionModePostDepthCoverage = 4446, + SpvExecutionModeStencilRefReplacingEXT = 5027, + SpvExecutionModeMax = 0x7fffffff, +} SpvExecutionMode; + +typedef enum SpvStorageClass_ { + SpvStorageClassUniformConstant = 0, + SpvStorageClassInput = 1, + SpvStorageClassUniform = 2, + SpvStorageClassOutput = 3, + SpvStorageClassWorkgroup = 4, + SpvStorageClassCrossWorkgroup = 5, + SpvStorageClassPrivate = 6, + SpvStorageClassFunction = 7, + SpvStorageClassGeneric = 8, + SpvStorageClassPushConstant = 9, + SpvStorageClassAtomicCounter = 10, + SpvStorageClassImage = 11, + SpvStorageClassStorageBuffer = 12, + SpvStorageClassMax = 0x7fffffff, +} SpvStorageClass; + +typedef enum SpvDim_ { + SpvDim1D = 0, + SpvDim2D = 1, + SpvDim3D = 2, + SpvDimCube = 3, + SpvDimRect = 4, + SpvDimBuffer = 5, + SpvDimSubpassData = 6, + SpvDimMax = 0x7fffffff, +} SpvDim; + +typedef enum SpvSamplerAddressingMode_ { + SpvSamplerAddressingModeNone = 0, + SpvSamplerAddressingModeClampToEdge = 1, + SpvSamplerAddressingModeClamp = 2, + SpvSamplerAddressingModeRepeat = 3, + SpvSamplerAddressingModeRepeatMirrored = 4, + SpvSamplerAddressingModeMax = 0x7fffffff, +} SpvSamplerAddressingMode; + +typedef enum SpvSamplerFilterMode_ { + SpvSamplerFilterModeNearest = 0, + SpvSamplerFilterModeLinear = 1, + SpvSamplerFilterModeMax = 0x7fffffff, +} SpvSamplerFilterMode; + +typedef enum SpvImageFormat_ { + SpvImageFormatUnknown = 0, + SpvImageFormatRgba32f = 1, + SpvImageFormatRgba16f = 2, + SpvImageFormatR32f = 3, + SpvImageFormatRgba8 = 4, + SpvImageFormatRgba8Snorm = 5, + SpvImageFormatRg32f = 6, + SpvImageFormatRg16f = 7, + SpvImageFormatR11fG11fB10f = 8, + SpvImageFormatR16f = 9, + SpvImageFormatRgba16 = 10, + SpvImageFormatRgb10A2 = 11, + SpvImageFormatRg16 = 12, + SpvImageFormatRg8 = 13, + SpvImageFormatR16 = 14, + SpvImageFormatR8 = 15, + SpvImageFormatRgba16Snorm = 16, + SpvImageFormatRg16Snorm = 17, + SpvImageFormatRg8Snorm = 18, + SpvImageFormatR16Snorm = 19, + SpvImageFormatR8Snorm = 20, + SpvImageFormatRgba32i = 21, + SpvImageFormatRgba16i = 22, + SpvImageFormatRgba8i = 23, + SpvImageFormatR32i = 24, + SpvImageFormatRg32i = 25, + SpvImageFormatRg16i = 26, + SpvImageFormatRg8i = 27, + SpvImageFormatR16i = 28, + SpvImageFormatR8i = 29, + SpvImageFormatRgba32ui = 30, + SpvImageFormatRgba16ui = 31, + SpvImageFormatRgba8ui = 32, + SpvImageFormatR32ui = 33, + SpvImageFormatRgb10a2ui = 34, + SpvImageFormatRg32ui = 35, + SpvImageFormatRg16ui = 36, + SpvImageFormatRg8ui = 37, + SpvImageFormatR16ui = 38, + SpvImageFormatR8ui = 39, + SpvImageFormatMax = 0x7fffffff, +} SpvImageFormat; + +typedef enum SpvImageChannelOrder_ { + SpvImageChannelOrderR = 0, + SpvImageChannelOrderA = 1, + SpvImageChannelOrderRG = 2, + SpvImageChannelOrderRA = 3, + SpvImageChannelOrderRGB = 4, + SpvImageChannelOrderRGBA = 5, + SpvImageChannelOrderBGRA = 6, + SpvImageChannelOrderARGB = 7, + SpvImageChannelOrderIntensity = 8, + SpvImageChannelOrderLuminance = 9, + SpvImageChannelOrderRx = 10, + SpvImageChannelOrderRGx = 11, + SpvImageChannelOrderRGBx = 12, + SpvImageChannelOrderDepth = 13, + SpvImageChannelOrderDepthStencil = 14, + SpvImageChannelOrdersRGB = 15, + SpvImageChannelOrdersRGBx = 16, + SpvImageChannelOrdersRGBA = 17, + SpvImageChannelOrdersBGRA = 18, + SpvImageChannelOrderABGR = 19, + SpvImageChannelOrderMax = 0x7fffffff, +} SpvImageChannelOrder; + +typedef enum SpvImageChannelDataType_ { + SpvImageChannelDataTypeSnormInt8 = 0, + SpvImageChannelDataTypeSnormInt16 = 1, + SpvImageChannelDataTypeUnormInt8 = 2, + SpvImageChannelDataTypeUnormInt16 = 3, + SpvImageChannelDataTypeUnormShort565 = 4, + SpvImageChannelDataTypeUnormShort555 = 5, + SpvImageChannelDataTypeUnormInt101010 = 6, + SpvImageChannelDataTypeSignedInt8 = 7, + SpvImageChannelDataTypeSignedInt16 = 8, + SpvImageChannelDataTypeSignedInt32 = 9, + SpvImageChannelDataTypeUnsignedInt8 = 10, + SpvImageChannelDataTypeUnsignedInt16 = 11, + SpvImageChannelDataTypeUnsignedInt32 = 12, + SpvImageChannelDataTypeHalfFloat = 13, + SpvImageChannelDataTypeFloat = 14, + SpvImageChannelDataTypeUnormInt24 = 15, + SpvImageChannelDataTypeUnormInt101010_2 = 16, + SpvImageChannelDataTypeMax = 0x7fffffff, +} SpvImageChannelDataType; + +typedef enum SpvImageOperandsShift_ { + SpvImageOperandsBiasShift = 0, + SpvImageOperandsLodShift = 1, + SpvImageOperandsGradShift = 2, + SpvImageOperandsConstOffsetShift = 3, + SpvImageOperandsOffsetShift = 4, + SpvImageOperandsConstOffsetsShift = 5, + SpvImageOperandsSampleShift = 6, + SpvImageOperandsMinLodShift = 7, + SpvImageOperandsMax = 0x7fffffff, +} SpvImageOperandsShift; + +typedef enum SpvImageOperandsMask_ { + SpvImageOperandsMaskNone = 0, + SpvImageOperandsBiasMask = 0x00000001, + SpvImageOperandsLodMask = 0x00000002, + SpvImageOperandsGradMask = 0x00000004, + SpvImageOperandsConstOffsetMask = 0x00000008, + SpvImageOperandsOffsetMask = 0x00000010, + SpvImageOperandsConstOffsetsMask = 0x00000020, + SpvImageOperandsSampleMask = 0x00000040, + SpvImageOperandsMinLodMask = 0x00000080, +} SpvImageOperandsMask; + +typedef enum SpvFPFastMathModeShift_ { + SpvFPFastMathModeNotNaNShift = 0, + SpvFPFastMathModeNotInfShift = 1, + SpvFPFastMathModeNSZShift = 2, + SpvFPFastMathModeAllowRecipShift = 3, + SpvFPFastMathModeFastShift = 4, + SpvFPFastMathModeMax = 0x7fffffff, +} SpvFPFastMathModeShift; + +typedef enum SpvFPFastMathModeMask_ { + SpvFPFastMathModeMaskNone = 0, + SpvFPFastMathModeNotNaNMask = 0x00000001, + SpvFPFastMathModeNotInfMask = 0x00000002, + SpvFPFastMathModeNSZMask = 0x00000004, + SpvFPFastMathModeAllowRecipMask = 0x00000008, + SpvFPFastMathModeFastMask = 0x00000010, +} SpvFPFastMathModeMask; + +typedef enum SpvFPRoundingMode_ { + SpvFPRoundingModeRTE = 0, + SpvFPRoundingModeRTZ = 1, + SpvFPRoundingModeRTP = 2, + SpvFPRoundingModeRTN = 3, + SpvFPRoundingModeMax = 0x7fffffff, +} SpvFPRoundingMode; + +typedef enum SpvLinkageType_ { + SpvLinkageTypeExport = 0, + SpvLinkageTypeImport = 1, + SpvLinkageTypeMax = 0x7fffffff, +} SpvLinkageType; + +typedef enum SpvAccessQualifier_ { + SpvAccessQualifierReadOnly = 0, + SpvAccessQualifierWriteOnly = 1, + SpvAccessQualifierReadWrite = 2, + SpvAccessQualifierMax = 0x7fffffff, +} SpvAccessQualifier; + +typedef enum SpvFunctionParameterAttribute_ { + SpvFunctionParameterAttributeZext = 0, + SpvFunctionParameterAttributeSext = 1, + SpvFunctionParameterAttributeByVal = 2, + SpvFunctionParameterAttributeSret = 3, + SpvFunctionParameterAttributeNoAlias = 4, + SpvFunctionParameterAttributeNoCapture = 5, + SpvFunctionParameterAttributeNoWrite = 6, + SpvFunctionParameterAttributeNoReadWrite = 7, + SpvFunctionParameterAttributeMax = 0x7fffffff, +} SpvFunctionParameterAttribute; + +typedef enum SpvDecoration_ { + SpvDecorationRelaxedPrecision = 0, + SpvDecorationSpecId = 1, + SpvDecorationBlock = 2, + SpvDecorationBufferBlock = 3, + SpvDecorationRowMajor = 4, + SpvDecorationColMajor = 5, + SpvDecorationArrayStride = 6, + SpvDecorationMatrixStride = 7, + SpvDecorationGLSLShared = 8, + SpvDecorationGLSLPacked = 9, + SpvDecorationCPacked = 10, + SpvDecorationBuiltIn = 11, + SpvDecorationNoPerspective = 13, + SpvDecorationFlat = 14, + SpvDecorationPatch = 15, + SpvDecorationCentroid = 16, + SpvDecorationSample = 17, + SpvDecorationInvariant = 18, + SpvDecorationRestrict = 19, + SpvDecorationAliased = 20, + SpvDecorationVolatile = 21, + SpvDecorationConstant = 22, + SpvDecorationCoherent = 23, + SpvDecorationNonWritable = 24, + SpvDecorationNonReadable = 25, + SpvDecorationUniform = 26, + SpvDecorationSaturatedConversion = 28, + SpvDecorationStream = 29, + SpvDecorationLocation = 30, + SpvDecorationComponent = 31, + SpvDecorationIndex = 32, + SpvDecorationBinding = 33, + SpvDecorationDescriptorSet = 34, + SpvDecorationOffset = 35, + SpvDecorationXfbBuffer = 36, + SpvDecorationXfbStride = 37, + SpvDecorationFuncParamAttr = 38, + SpvDecorationFPRoundingMode = 39, + SpvDecorationFPFastMathMode = 40, + SpvDecorationLinkageAttributes = 41, + SpvDecorationNoContraction = 42, + SpvDecorationInputAttachmentIndex = 43, + SpvDecorationAlignment = 44, + SpvDecorationExplicitInterpAMD = 4999, + SpvDecorationOverrideCoverageNV = 5248, + SpvDecorationPassthroughNV = 5250, + SpvDecorationViewportRelativeNV = 5252, + SpvDecorationSecondaryViewportRelativeNV = 5256, + SpvDecorationHlslCounterBufferGOOGLE = 5634, + SpvDecorationHlslSemanticGOOGLE = 5635, + SpvDecorationMax = 0x7fffffff, +} SpvDecoration; + +typedef enum SpvBuiltIn_ { + SpvBuiltInPosition = 0, + SpvBuiltInPointSize = 1, + SpvBuiltInClipDistance = 3, + SpvBuiltInCullDistance = 4, + SpvBuiltInVertexId = 5, + SpvBuiltInInstanceId = 6, + SpvBuiltInPrimitiveId = 7, + SpvBuiltInInvocationId = 8, + SpvBuiltInLayer = 9, + SpvBuiltInViewportIndex = 10, + SpvBuiltInTessLevelOuter = 11, + SpvBuiltInTessLevelInner = 12, + SpvBuiltInTessCoord = 13, + SpvBuiltInPatchVertices = 14, + SpvBuiltInFragCoord = 15, + SpvBuiltInPointCoord = 16, + SpvBuiltInFrontFacing = 17, + SpvBuiltInSampleId = 18, + SpvBuiltInSamplePosition = 19, + SpvBuiltInSampleMask = 20, + SpvBuiltInFragDepth = 22, + SpvBuiltInHelperInvocation = 23, + SpvBuiltInNumWorkgroups = 24, + SpvBuiltInWorkgroupSize = 25, + SpvBuiltInWorkgroupId = 26, + SpvBuiltInLocalInvocationId = 27, + SpvBuiltInGlobalInvocationId = 28, + SpvBuiltInLocalInvocationIndex = 29, + SpvBuiltInWorkDim = 30, + SpvBuiltInGlobalSize = 31, + SpvBuiltInEnqueuedWorkgroupSize = 32, + SpvBuiltInGlobalOffset = 33, + SpvBuiltInGlobalLinearId = 34, + SpvBuiltInSubgroupSize = 36, + SpvBuiltInSubgroupMaxSize = 37, + SpvBuiltInNumSubgroups = 38, + SpvBuiltInNumEnqueuedSubgroups = 39, + SpvBuiltInSubgroupId = 40, + SpvBuiltInSubgroupLocalInvocationId = 41, + SpvBuiltInVertexIndex = 42, + SpvBuiltInInstanceIndex = 43, + SpvBuiltInSubgroupEqMaskKHR = 4416, + SpvBuiltInSubgroupGeMaskKHR = 4417, + SpvBuiltInSubgroupGtMaskKHR = 4418, + SpvBuiltInSubgroupLeMaskKHR = 4419, + SpvBuiltInSubgroupLtMaskKHR = 4420, + SpvBuiltInBaseVertex = 4424, + SpvBuiltInBaseInstance = 4425, + SpvBuiltInDrawIndex = 4426, + SpvBuiltInDeviceIndex = 4438, + SpvBuiltInViewIndex = 4440, + SpvBuiltInBaryCoordNoPerspAMD = 4992, + SpvBuiltInBaryCoordNoPerspCentroidAMD = 4993, + SpvBuiltInBaryCoordNoPerspSampleAMD = 4994, + SpvBuiltInBaryCoordSmoothAMD = 4995, + SpvBuiltInBaryCoordSmoothCentroidAMD = 4996, + SpvBuiltInBaryCoordSmoothSampleAMD = 4997, + SpvBuiltInBaryCoordPullModelAMD = 4998, + SpvBuiltInFragStencilRefEXT = 5014, + SpvBuiltInViewportMaskNV = 5253, + SpvBuiltInSecondaryPositionNV = 5257, + SpvBuiltInSecondaryViewportMaskNV = 5258, + SpvBuiltInPositionPerViewNV = 5261, + SpvBuiltInViewportMaskPerViewNV = 5262, + SpvBuiltInMax = 0x7fffffff, +} SpvBuiltIn; + +typedef enum SpvSelectionControlShift_ { + SpvSelectionControlFlattenShift = 0, + SpvSelectionControlDontFlattenShift = 1, + SpvSelectionControlMax = 0x7fffffff, +} SpvSelectionControlShift; + +typedef enum SpvSelectionControlMask_ { + SpvSelectionControlMaskNone = 0, + SpvSelectionControlFlattenMask = 0x00000001, + SpvSelectionControlDontFlattenMask = 0x00000002, +} SpvSelectionControlMask; + +typedef enum SpvLoopControlShift_ { + SpvLoopControlUnrollShift = 0, + SpvLoopControlDontUnrollShift = 1, + SpvLoopControlMax = 0x7fffffff, +} SpvLoopControlShift; + +typedef enum SpvLoopControlMask_ { + SpvLoopControlMaskNone = 0, + SpvLoopControlUnrollMask = 0x00000001, + SpvLoopControlDontUnrollMask = 0x00000002, +} SpvLoopControlMask; + +typedef enum SpvFunctionControlShift_ { + SpvFunctionControlInlineShift = 0, + SpvFunctionControlDontInlineShift = 1, + SpvFunctionControlPureShift = 2, + SpvFunctionControlConstShift = 3, + SpvFunctionControlMax = 0x7fffffff, +} SpvFunctionControlShift; + +typedef enum SpvFunctionControlMask_ { + SpvFunctionControlMaskNone = 0, + SpvFunctionControlInlineMask = 0x00000001, + SpvFunctionControlDontInlineMask = 0x00000002, + SpvFunctionControlPureMask = 0x00000004, + SpvFunctionControlConstMask = 0x00000008, +} SpvFunctionControlMask; + +typedef enum SpvMemorySemanticsShift_ { + SpvMemorySemanticsAcquireShift = 1, + SpvMemorySemanticsReleaseShift = 2, + SpvMemorySemanticsAcquireReleaseShift = 3, + SpvMemorySemanticsSequentiallyConsistentShift = 4, + SpvMemorySemanticsUniformMemoryShift = 6, + SpvMemorySemanticsSubgroupMemoryShift = 7, + SpvMemorySemanticsWorkgroupMemoryShift = 8, + SpvMemorySemanticsCrossWorkgroupMemoryShift = 9, + SpvMemorySemanticsAtomicCounterMemoryShift = 10, + SpvMemorySemanticsImageMemoryShift = 11, + SpvMemorySemanticsMax = 0x7fffffff, +} SpvMemorySemanticsShift; + +typedef enum SpvMemorySemanticsMask_ { + SpvMemorySemanticsMaskNone = 0, + SpvMemorySemanticsAcquireMask = 0x00000002, + SpvMemorySemanticsReleaseMask = 0x00000004, + SpvMemorySemanticsAcquireReleaseMask = 0x00000008, + SpvMemorySemanticsSequentiallyConsistentMask = 0x00000010, + SpvMemorySemanticsUniformMemoryMask = 0x00000040, + SpvMemorySemanticsSubgroupMemoryMask = 0x00000080, + SpvMemorySemanticsWorkgroupMemoryMask = 0x00000100, + SpvMemorySemanticsCrossWorkgroupMemoryMask = 0x00000200, + SpvMemorySemanticsAtomicCounterMemoryMask = 0x00000400, + SpvMemorySemanticsImageMemoryMask = 0x00000800, +} SpvMemorySemanticsMask; + +typedef enum SpvMemoryAccessShift_ { + SpvMemoryAccessVolatileShift = 0, + SpvMemoryAccessAlignedShift = 1, + SpvMemoryAccessNontemporalShift = 2, + SpvMemoryAccessMax = 0x7fffffff, +} SpvMemoryAccessShift; + +typedef enum SpvMemoryAccessMask_ { + SpvMemoryAccessMaskNone = 0, + SpvMemoryAccessVolatileMask = 0x00000001, + SpvMemoryAccessAlignedMask = 0x00000002, + SpvMemoryAccessNontemporalMask = 0x00000004, +} SpvMemoryAccessMask; + +typedef enum SpvScope_ { + SpvScopeCrossDevice = 0, + SpvScopeDevice = 1, + SpvScopeWorkgroup = 2, + SpvScopeSubgroup = 3, + SpvScopeInvocation = 4, + SpvScopeMax = 0x7fffffff, +} SpvScope; + +typedef enum SpvGroupOperation_ { + SpvGroupOperationReduce = 0, + SpvGroupOperationInclusiveScan = 1, + SpvGroupOperationExclusiveScan = 2, + SpvGroupOperationMax = 0x7fffffff, +} SpvGroupOperation; + +typedef enum SpvKernelEnqueueFlags_ { + SpvKernelEnqueueFlagsNoWait = 0, + SpvKernelEnqueueFlagsWaitKernel = 1, + SpvKernelEnqueueFlagsWaitWorkGroup = 2, + SpvKernelEnqueueFlagsMax = 0x7fffffff, +} SpvKernelEnqueueFlags; + +typedef enum SpvKernelProfilingInfoShift_ { + SpvKernelProfilingInfoCmdExecTimeShift = 0, + SpvKernelProfilingInfoMax = 0x7fffffff, +} SpvKernelProfilingInfoShift; + +typedef enum SpvKernelProfilingInfoMask_ { + SpvKernelProfilingInfoMaskNone = 0, + SpvKernelProfilingInfoCmdExecTimeMask = 0x00000001, +} SpvKernelProfilingInfoMask; + +typedef enum SpvCapability_ { + SpvCapabilityMatrix = 0, + SpvCapabilityShader = 1, + SpvCapabilityGeometry = 2, + SpvCapabilityTessellation = 3, + SpvCapabilityAddresses = 4, + SpvCapabilityLinkage = 5, + SpvCapabilityKernel = 6, + SpvCapabilityVector16 = 7, + SpvCapabilityFloat16Buffer = 8, + SpvCapabilityFloat16 = 9, + SpvCapabilityFloat64 = 10, + SpvCapabilityInt64 = 11, + SpvCapabilityInt64Atomics = 12, + SpvCapabilityImageBasic = 13, + SpvCapabilityImageReadWrite = 14, + SpvCapabilityImageMipmap = 15, + SpvCapabilityPipes = 17, + SpvCapabilityGroups = 18, + SpvCapabilityDeviceEnqueue = 19, + SpvCapabilityLiteralSampler = 20, + SpvCapabilityAtomicStorage = 21, + SpvCapabilityInt16 = 22, + SpvCapabilityTessellationPointSize = 23, + SpvCapabilityGeometryPointSize = 24, + SpvCapabilityImageGatherExtended = 25, + SpvCapabilityStorageImageMultisample = 27, + SpvCapabilityUniformBufferArrayDynamicIndexing = 28, + SpvCapabilitySampledImageArrayDynamicIndexing = 29, + SpvCapabilityStorageBufferArrayDynamicIndexing = 30, + SpvCapabilityStorageImageArrayDynamicIndexing = 31, + SpvCapabilityClipDistance = 32, + SpvCapabilityCullDistance = 33, + SpvCapabilityImageCubeArray = 34, + SpvCapabilitySampleRateShading = 35, + SpvCapabilityImageRect = 36, + SpvCapabilitySampledRect = 37, + SpvCapabilityGenericPointer = 38, + SpvCapabilityInt8 = 39, + SpvCapabilityInputAttachment = 40, + SpvCapabilitySparseResidency = 41, + SpvCapabilityMinLod = 42, + SpvCapabilitySampled1D = 43, + SpvCapabilityImage1D = 44, + SpvCapabilitySampledCubeArray = 45, + SpvCapabilitySampledBuffer = 46, + SpvCapabilityImageBuffer = 47, + SpvCapabilityImageMSArray = 48, + SpvCapabilityStorageImageExtendedFormats = 49, + SpvCapabilityImageQuery = 50, + SpvCapabilityDerivativeControl = 51, + SpvCapabilityInterpolationFunction = 52, + SpvCapabilityTransformFeedback = 53, + SpvCapabilityGeometryStreams = 54, + SpvCapabilityStorageImageReadWithoutFormat = 55, + SpvCapabilityStorageImageWriteWithoutFormat = 56, + SpvCapabilityMultiViewport = 57, + SpvCapabilitySubgroupBallotKHR = 4423, + SpvCapabilityDrawParameters = 4427, + SpvCapabilitySubgroupVoteKHR = 4431, + SpvCapabilityStorageBuffer16BitAccess = 4433, + SpvCapabilityStorageUniformBufferBlock16 = 4433, + SpvCapabilityStorageUniform16 = 4434, + SpvCapabilityUniformAndStorageBuffer16BitAccess = 4434, + SpvCapabilityStoragePushConstant16 = 4435, + SpvCapabilityStorageInputOutput16 = 4436, + SpvCapabilityDeviceGroup = 4437, + SpvCapabilityMultiView = 4439, + SpvCapabilityVariablePointersStorageBuffer = 4441, + SpvCapabilityVariablePointers = 4442, + SpvCapabilityAtomicStorageOps = 4445, + SpvCapabilitySampleMaskPostDepthCoverage = 4447, + SpvCapabilityImageGatherBiasLodAMD = 5009, + SpvCapabilityFragmentMaskAMD = 5010, + SpvCapabilityStencilExportEXT = 5013, + SpvCapabilityImageReadWriteLodAMD = 5015, + SpvCapabilitySampleMaskOverrideCoverageNV = 5249, + SpvCapabilityGeometryShaderPassthroughNV = 5251, + SpvCapabilityShaderViewportIndexLayerEXT = 5254, + SpvCapabilityShaderViewportIndexLayerNV = 5254, + SpvCapabilityShaderViewportMaskNV = 5255, + SpvCapabilityShaderStereoViewNV = 5259, + SpvCapabilityPerViewAttributesNV = 5260, + SpvCapabilitySubgroupShuffleINTEL = 5568, + SpvCapabilitySubgroupBufferBlockIOINTEL = 5569, + SpvCapabilitySubgroupImageBlockIOINTEL = 5570, + SpvCapabilityMax = 0x7fffffff, +} SpvCapability; + +typedef enum SpvOp_ { + SpvOpNop = 0, + SpvOpUndef = 1, + SpvOpSourceContinued = 2, + SpvOpSource = 3, + SpvOpSourceExtension = 4, + SpvOpName = 5, + SpvOpMemberName = 6, + SpvOpString = 7, + SpvOpLine = 8, + SpvOpExtension = 10, + SpvOpExtInstImport = 11, + SpvOpExtInst = 12, + SpvOpMemoryModel = 14, + SpvOpEntryPoint = 15, + SpvOpExecutionMode = 16, + SpvOpCapability = 17, + SpvOpTypeVoid = 19, + SpvOpTypeBool = 20, + SpvOpTypeInt = 21, + SpvOpTypeFloat = 22, + SpvOpTypeVector = 23, + SpvOpTypeMatrix = 24, + SpvOpTypeImage = 25, + SpvOpTypeSampler = 26, + SpvOpTypeSampledImage = 27, + SpvOpTypeArray = 28, + SpvOpTypeRuntimeArray = 29, + SpvOpTypeStruct = 30, + SpvOpTypeOpaque = 31, + SpvOpTypePointer = 32, + SpvOpTypeFunction = 33, + SpvOpTypeEvent = 34, + SpvOpTypeDeviceEvent = 35, + SpvOpTypeReserveId = 36, + SpvOpTypeQueue = 37, + SpvOpTypePipe = 38, + SpvOpTypeForwardPointer = 39, + SpvOpConstantTrue = 41, + SpvOpConstantFalse = 42, + SpvOpConstant = 43, + SpvOpConstantComposite = 44, + SpvOpConstantSampler = 45, + SpvOpConstantNull = 46, + SpvOpSpecConstantTrue = 48, + SpvOpSpecConstantFalse = 49, + SpvOpSpecConstant = 50, + SpvOpSpecConstantComposite = 51, + SpvOpSpecConstantOp = 52, + SpvOpFunction = 54, + SpvOpFunctionParameter = 55, + SpvOpFunctionEnd = 56, + SpvOpFunctionCall = 57, + SpvOpVariable = 59, + SpvOpImageTexelPointer = 60, + SpvOpLoad = 61, + SpvOpStore = 62, + SpvOpCopyMemory = 63, + SpvOpCopyMemorySized = 64, + SpvOpAccessChain = 65, + SpvOpInBoundsAccessChain = 66, + SpvOpPtrAccessChain = 67, + SpvOpArrayLength = 68, + SpvOpGenericPtrMemSemantics = 69, + SpvOpInBoundsPtrAccessChain = 70, + SpvOpDecorate = 71, + SpvOpMemberDecorate = 72, + SpvOpDecorationGroup = 73, + SpvOpGroupDecorate = 74, + SpvOpGroupMemberDecorate = 75, + SpvOpVectorExtractDynamic = 77, + SpvOpVectorInsertDynamic = 78, + SpvOpVectorShuffle = 79, + SpvOpCompositeConstruct = 80, + SpvOpCompositeExtract = 81, + SpvOpCompositeInsert = 82, + SpvOpCopyObject = 83, + SpvOpTranspose = 84, + SpvOpSampledImage = 86, + SpvOpImageSampleImplicitLod = 87, + SpvOpImageSampleExplicitLod = 88, + SpvOpImageSampleDrefImplicitLod = 89, + SpvOpImageSampleDrefExplicitLod = 90, + SpvOpImageSampleProjImplicitLod = 91, + SpvOpImageSampleProjExplicitLod = 92, + SpvOpImageSampleProjDrefImplicitLod = 93, + SpvOpImageSampleProjDrefExplicitLod = 94, + SpvOpImageFetch = 95, + SpvOpImageGather = 96, + SpvOpImageDrefGather = 97, + SpvOpImageRead = 98, + SpvOpImageWrite = 99, + SpvOpImage = 100, + SpvOpImageQueryFormat = 101, + SpvOpImageQueryOrder = 102, + SpvOpImageQuerySizeLod = 103, + SpvOpImageQuerySize = 104, + SpvOpImageQueryLod = 105, + SpvOpImageQueryLevels = 106, + SpvOpImageQuerySamples = 107, + SpvOpConvertFToU = 109, + SpvOpConvertFToS = 110, + SpvOpConvertSToF = 111, + SpvOpConvertUToF = 112, + SpvOpUConvert = 113, + SpvOpSConvert = 114, + SpvOpFConvert = 115, + SpvOpQuantizeToF16 = 116, + SpvOpConvertPtrToU = 117, + SpvOpSatConvertSToU = 118, + SpvOpSatConvertUToS = 119, + SpvOpConvertUToPtr = 120, + SpvOpPtrCastToGeneric = 121, + SpvOpGenericCastToPtr = 122, + SpvOpGenericCastToPtrExplicit = 123, + SpvOpBitcast = 124, + SpvOpSNegate = 126, + SpvOpFNegate = 127, + SpvOpIAdd = 128, + SpvOpFAdd = 129, + SpvOpISub = 130, + SpvOpFSub = 131, + SpvOpIMul = 132, + SpvOpFMul = 133, + SpvOpUDiv = 134, + SpvOpSDiv = 135, + SpvOpFDiv = 136, + SpvOpUMod = 137, + SpvOpSRem = 138, + SpvOpSMod = 139, + SpvOpFRem = 140, + SpvOpFMod = 141, + SpvOpVectorTimesScalar = 142, + SpvOpMatrixTimesScalar = 143, + SpvOpVectorTimesMatrix = 144, + SpvOpMatrixTimesVector = 145, + SpvOpMatrixTimesMatrix = 146, + SpvOpOuterProduct = 147, + SpvOpDot = 148, + SpvOpIAddCarry = 149, + SpvOpISubBorrow = 150, + SpvOpUMulExtended = 151, + SpvOpSMulExtended = 152, + SpvOpAny = 154, + SpvOpAll = 155, + SpvOpIsNan = 156, + SpvOpIsInf = 157, + SpvOpIsFinite = 158, + SpvOpIsNormal = 159, + SpvOpSignBitSet = 160, + SpvOpLessOrGreater = 161, + SpvOpOrdered = 162, + SpvOpUnordered = 163, + SpvOpLogicalEqual = 164, + SpvOpLogicalNotEqual = 165, + SpvOpLogicalOr = 166, + SpvOpLogicalAnd = 167, + SpvOpLogicalNot = 168, + SpvOpSelect = 169, + SpvOpIEqual = 170, + SpvOpINotEqual = 171, + SpvOpUGreaterThan = 172, + SpvOpSGreaterThan = 173, + SpvOpUGreaterThanEqual = 174, + SpvOpSGreaterThanEqual = 175, + SpvOpULessThan = 176, + SpvOpSLessThan = 177, + SpvOpULessThanEqual = 178, + SpvOpSLessThanEqual = 179, + SpvOpFOrdEqual = 180, + SpvOpFUnordEqual = 181, + SpvOpFOrdNotEqual = 182, + SpvOpFUnordNotEqual = 183, + SpvOpFOrdLessThan = 184, + SpvOpFUnordLessThan = 185, + SpvOpFOrdGreaterThan = 186, + SpvOpFUnordGreaterThan = 187, + SpvOpFOrdLessThanEqual = 188, + SpvOpFUnordLessThanEqual = 189, + SpvOpFOrdGreaterThanEqual = 190, + SpvOpFUnordGreaterThanEqual = 191, + SpvOpShiftRightLogical = 194, + SpvOpShiftRightArithmetic = 195, + SpvOpShiftLeftLogical = 196, + SpvOpBitwiseOr = 197, + SpvOpBitwiseXor = 198, + SpvOpBitwiseAnd = 199, + SpvOpNot = 200, + SpvOpBitFieldInsert = 201, + SpvOpBitFieldSExtract = 202, + SpvOpBitFieldUExtract = 203, + SpvOpBitReverse = 204, + SpvOpBitCount = 205, + SpvOpDPdx = 207, + SpvOpDPdy = 208, + SpvOpFwidth = 209, + SpvOpDPdxFine = 210, + SpvOpDPdyFine = 211, + SpvOpFwidthFine = 212, + SpvOpDPdxCoarse = 213, + SpvOpDPdyCoarse = 214, + SpvOpFwidthCoarse = 215, + SpvOpEmitVertex = 218, + SpvOpEndPrimitive = 219, + SpvOpEmitStreamVertex = 220, + SpvOpEndStreamPrimitive = 221, + SpvOpControlBarrier = 224, + SpvOpMemoryBarrier = 225, + SpvOpAtomicLoad = 227, + SpvOpAtomicStore = 228, + SpvOpAtomicExchange = 229, + SpvOpAtomicCompareExchange = 230, + SpvOpAtomicCompareExchangeWeak = 231, + SpvOpAtomicIIncrement = 232, + SpvOpAtomicIDecrement = 233, + SpvOpAtomicIAdd = 234, + SpvOpAtomicISub = 235, + SpvOpAtomicSMin = 236, + SpvOpAtomicUMin = 237, + SpvOpAtomicSMax = 238, + SpvOpAtomicUMax = 239, + SpvOpAtomicAnd = 240, + SpvOpAtomicOr = 241, + SpvOpAtomicXor = 242, + SpvOpPhi = 245, + SpvOpLoopMerge = 246, + SpvOpSelectionMerge = 247, + SpvOpLabel = 248, + SpvOpBranch = 249, + SpvOpBranchConditional = 250, + SpvOpSwitch = 251, + SpvOpKill = 252, + SpvOpReturn = 253, + SpvOpReturnValue = 254, + SpvOpUnreachable = 255, + SpvOpLifetimeStart = 256, + SpvOpLifetimeStop = 257, + SpvOpGroupAsyncCopy = 259, + SpvOpGroupWaitEvents = 260, + SpvOpGroupAll = 261, + SpvOpGroupAny = 262, + SpvOpGroupBroadcast = 263, + SpvOpGroupIAdd = 264, + SpvOpGroupFAdd = 265, + SpvOpGroupFMin = 266, + SpvOpGroupUMin = 267, + SpvOpGroupSMin = 268, + SpvOpGroupFMax = 269, + SpvOpGroupUMax = 270, + SpvOpGroupSMax = 271, + SpvOpReadPipe = 274, + SpvOpWritePipe = 275, + SpvOpReservedReadPipe = 276, + SpvOpReservedWritePipe = 277, + SpvOpReserveReadPipePackets = 278, + SpvOpReserveWritePipePackets = 279, + SpvOpCommitReadPipe = 280, + SpvOpCommitWritePipe = 281, + SpvOpIsValidReserveId = 282, + SpvOpGetNumPipePackets = 283, + SpvOpGetMaxPipePackets = 284, + SpvOpGroupReserveReadPipePackets = 285, + SpvOpGroupReserveWritePipePackets = 286, + SpvOpGroupCommitReadPipe = 287, + SpvOpGroupCommitWritePipe = 288, + SpvOpEnqueueMarker = 291, + SpvOpEnqueueKernel = 292, + SpvOpGetKernelNDrangeSubGroupCount = 293, + SpvOpGetKernelNDrangeMaxSubGroupSize = 294, + SpvOpGetKernelWorkGroupSize = 295, + SpvOpGetKernelPreferredWorkGroupSizeMultiple = 296, + SpvOpRetainEvent = 297, + SpvOpReleaseEvent = 298, + SpvOpCreateUserEvent = 299, + SpvOpIsValidEvent = 300, + SpvOpSetUserEventStatus = 301, + SpvOpCaptureEventProfilingInfo = 302, + SpvOpGetDefaultQueue = 303, + SpvOpBuildNDRange = 304, + SpvOpImageSparseSampleImplicitLod = 305, + SpvOpImageSparseSampleExplicitLod = 306, + SpvOpImageSparseSampleDrefImplicitLod = 307, + SpvOpImageSparseSampleDrefExplicitLod = 308, + SpvOpImageSparseSampleProjImplicitLod = 309, + SpvOpImageSparseSampleProjExplicitLod = 310, + SpvOpImageSparseSampleProjDrefImplicitLod = 311, + SpvOpImageSparseSampleProjDrefExplicitLod = 312, + SpvOpImageSparseFetch = 313, + SpvOpImageSparseGather = 314, + SpvOpImageSparseDrefGather = 315, + SpvOpImageSparseTexelsResident = 316, + SpvOpNoLine = 317, + SpvOpAtomicFlagTestAndSet = 318, + SpvOpAtomicFlagClear = 319, + SpvOpImageSparseRead = 320, + SpvOpDecorateId = 332, + SpvOpSubgroupBallotKHR = 4421, + SpvOpSubgroupFirstInvocationKHR = 4422, + SpvOpSubgroupAllKHR = 4428, + SpvOpSubgroupAnyKHR = 4429, + SpvOpSubgroupAllEqualKHR = 4430, + SpvOpSubgroupReadInvocationKHR = 4432, + SpvOpGroupIAddNonUniformAMD = 5000, + SpvOpGroupFAddNonUniformAMD = 5001, + SpvOpGroupFMinNonUniformAMD = 5002, + SpvOpGroupUMinNonUniformAMD = 5003, + SpvOpGroupSMinNonUniformAMD = 5004, + SpvOpGroupFMaxNonUniformAMD = 5005, + SpvOpGroupUMaxNonUniformAMD = 5006, + SpvOpGroupSMaxNonUniformAMD = 5007, + SpvOpFragmentMaskFetchAMD = 5011, + SpvOpFragmentFetchAMD = 5012, + SpvOpSubgroupShuffleINTEL = 5571, + SpvOpSubgroupShuffleDownINTEL = 5572, + SpvOpSubgroupShuffleUpINTEL = 5573, + SpvOpSubgroupShuffleXorINTEL = 5574, + SpvOpSubgroupBlockReadINTEL = 5575, + SpvOpSubgroupBlockWriteINTEL = 5576, + SpvOpSubgroupImageBlockReadINTEL = 5577, + SpvOpSubgroupImageBlockWriteINTEL = 5578, + SpvOpDecorateStringGOOGLE = 5632, + SpvOpMemberDecorateStringGOOGLE = 5633, + SpvOpMax = 0x7fffffff, +} SpvOp; + +#endif // #ifndef spirv_H + diff --git a/dependencies/wasm/CMakeLists.txt b/dependencies/wasm/CMakeLists.txt index 8481109ac1f3..0c9d76007007 100644 --- a/dependencies/wasm/CMakeLists.txt +++ b/dependencies/wasm/CMakeLists.txt @@ -16,7 +16,7 @@ if ("${CMAKE_HOST_SYSTEM_NAME}" STREQUAL "Windows") endif () if (WITH_WABT) - set(WABT_VER 1.0.27) + set(WABT_VER 1.0.29) message(STATUS "Fetching WABT ${WABT_VER}...") FetchContent_Declare(wabt @@ -87,11 +87,7 @@ function(add_wasm_executable TARGET) # target_link_libraries(${TARGET} PRIVATE ${args_DEPS}) # endif () - find_program(EMCC emcc HINTS "$ENV{EMSDK}/upstream/emscripten") - - if (NOT EMCC) - message(FATAL_ERROR "Building tests or apps for WASM requires that EMSDK point to a valid Emscripten install.") - endif () + find_program(EMCC emcc REQUIRED HINTS "$ENV{EMSDK}/upstream/emscripten") # TODO: this is currently hardcoded to settings that are sensible for most of Halide's # internal purposes. Consider adding ways to customize this as appropriate. @@ -155,12 +151,7 @@ function(add_wasm_halide_test TARGET) endfunction() function(find_node_js) - find_program(NODE_JS_EXECUTABLE node nodejs) - - # TODO: when we eventually upgrade to CMake >= 3.18, replace with REQUIRED in find_program - if (NOT NODE_JS_EXECUTABLE) - message(FATAL_ERROR "Could not find nodejs. Please set NODE_JS_EXECUTABLE on the CMake command line.") - endif () + find_program(NODE_JS_EXECUTABLE node nodejs REQUIRED) execute_process(COMMAND "${NODE_JS_EXECUTABLE}" --version OUTPUT_VARIABLE NODE_JS_VERSION_RAW diff --git a/packaging/CMakeLists.txt b/packaging/CMakeLists.txt index 35067a954c25..a68d1c153df6 100644 --- a/packaging/CMakeLists.txt +++ b/packaging/CMakeLists.txt @@ -45,18 +45,6 @@ foreach (dep IN ITEMS Halide_LLVM Halide_wabt) endif () endforeach () -## -# Python bindings -## - -if (WITH_PYTHON_BINDINGS) - set(Halide_INSTALL_PYTHONDIR "${CMAKE_INSTALL_LIBDIR}/python3/site-packages" - CACHE STRING "Path to Halide Python bindings folder") - install(TARGETS Halide_Python - LIBRARY DESTINATION ${Halide_INSTALL_PYTHONDIR} COMPONENT Halide_Python - NAMELINK_COMPONENT Halide_Python) -endif () - ## # Library-type-agnostic interface targets ## @@ -129,12 +117,8 @@ install(DIRECTORY ${Halide_SOURCE_DIR}/tools/ PATTERN "build_halide_h.cpp" EXCLUDE PATTERN "find_inverse.cpp" EXCLUDE) -install(FILES ${Halide_SOURCE_DIR}/src/autoschedulers/adams2019/autotune_loop.sh +install(PROGRAMS ${Halide_SOURCE_DIR}/src/autoschedulers/adams2019/autotune_loop.sh DESTINATION ${Halide_INSTALL_TOOLSDIR} - PERMISSIONS - OWNER_READ OWNER_WRITE OWNER_EXECUTE - GROUP_READ GROUP_EXECUTE - WORLD_READ WORLD_EXECUTE COMPONENT Halide_Development) ## @@ -153,13 +137,6 @@ if (WITH_TUTORIALS) PATTERN "*.jpg" PATTERN "*.mp4" PATTERN "*.png") - - if (WITH_PYTHON_BINDINGS) - install(DIRECTORY ${Halide_SOURCE_DIR}/python_bindings/tutorial/ - DESTINATION ${CMAKE_INSTALL_DOCDIR}/tutorial-python - COMPONENT Halide_Documentation - FILES_MATCHING PATTERN "*.py") - endif () endif () ## @@ -268,6 +245,11 @@ cpack_add_component(Halide_Development DESCRIPTION "Static Halide libraries and CMake development files" DEPENDS Halide_Runtime) +cpack_add_component(Halide_Python + DISPLAY_NAME "Python bindings" + DESCRIPTION "Python package providing bindings to Halide" + DEPENDS Halide_Runtime) + cpack_add_component(Halide_Documentation DISPLAY_NAME "Halide documentation" DESCRIPTION "Documentation for Halide") diff --git a/packaging/common/HalideConfig.cmake b/packaging/common/HalideConfig.cmake index 1a380a0f326b..729e938f971b 100644 --- a/packaging/common/HalideConfig.cmake +++ b/packaging/common/HalideConfig.cmake @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) macro(Halide_fail message) set(${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE "${message}") diff --git a/packaging/common/HalideHelpersConfig.cmake b/packaging/common/HalideHelpersConfig.cmake index b35b491af5b5..739c465fdd55 100644 --- a/packaging/common/HalideHelpersConfig.cmake +++ b/packaging/common/HalideHelpersConfig.cmake @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) set(Halide_HOST_TARGET @Halide_HOST_TARGET@) diff --git a/packaging/ubuntu/changelog b/packaging/ubuntu/changelog deleted file mode 100644 index f36c015c23e5..000000000000 --- a/packaging/ubuntu/changelog +++ /dev/null @@ -1,5 +0,0 @@ -@package_name@ (@CPACK_PACKAGE_VERSION@) UNRELEASED; urgency=low - - * Initial package release. - - -- @CPACK_PACKAGE_CONTACT@ @timestamp@ -0000 diff --git a/packaging/ubuntu/config.cmake b/packaging/ubuntu/config.cmake deleted file mode 100644 index 722cf47852a2..000000000000 --- a/packaging/ubuntu/config.cmake +++ /dev/null @@ -1,137 +0,0 @@ -cmake_minimum_required(VERSION 3.19) - -include("shared-Release/CPackConfig.cmake") - -## General setup - -set(CPACK_PACKAGE_CONTACT "Alex Reinking ") -set(CPACK_STRIP_FILES TRUE) -set(CPACK_PRE_BUILD_SCRIPTS "${CMAKE_CURRENT_LIST_DIR}/pre_build.cmake") - -############################## -## Components configuration ## -############################## - -# This is a mapping from CPack component names to CMake install() components. -# We use the identity mapping here for simplicity; some advanced configurations -# with GUI installers require these to diverge. -set(CPACK_COMPONENTS_HALIDE_RUNTIME Halide_Runtime) -set(CPACK_COMPONENTS_HALIDE_DEVELOPMENT Halide_Development) -set(CPACK_COMPONENTS_HALIDE_DOCUMENTATION Halide_Documentation) - -set(CPACK_COMPONENTS_ALL Halide_Runtime Halide_Development Halide_Documentation) - -set(CPACK_INSTALL_CMAKE_PROJECTS - static-Release Halide ALL / - shared-Release Halide ALL /) - -################################### -## Ubuntu-specific configuration ## -################################### - -# We set every variable documented here: https://cmake.org/cmake/help/latest/cpack_gen/deb.html -# even if it's just to the default. That way there are no surprises. - -set(CPACK_DEB_COMPONENT_INSTALL YES) - -set(CPACK_DEBIAN_HALIDE_RUNTIME_PACKAGE_NAME libHalide${CPACK_PACKAGE_VERSION_MAJOR}) -set(CPACK_DEBIAN_HALIDE_DEVELOPMENT_PACKAGE_NAME libHalide${CPACK_PACKAGE_VERSION_MAJOR}-dev) -set(CPACK_DEBIAN_HALIDE_DOCUMENTATION_PACKAGE_NAME libHalide${CPACK_PACKAGE_VERSION_MAJOR}-doc) - -set(CPACK_DEBIAN_HALIDE_RUNTIME_FILE_NAME DEB-DEFAULT) -set(CPACK_DEBIAN_HALIDE_DEVELOPMENT_FILE_NAME DEB-DEFAULT) -set(CPACK_DEBIAN_HALIDE_DOCUMENTATION_FILE_NAME DEB-DEFAULT) - -# Debian package versions look like: :- -# is a number that increases when changing the whole versioning schema. -# We would ideally _never_ have to set this since we're using semver. -# is the version number of the actual software being packaged. -# is the version number of the _package_. Set/increment this when fixing -# bugs in the package itself. This should also not be incremented too -# frequently. It's always safe to bump the patch version when in doubt. -unset(CPACK_DEBIAN_PACKAGE_EPOCH) -set(CPACK_DEBIAN_PACKAGE_VERSION "${CPACK_PACKAGE_VERSION}") -unset(CPACK_DEBIAN_PACKAGE_RELEASE) - -# The default here is the host system architecture. It will generally be best -# to package for ARM on ARM, for x86 on x86, etc. The documentation gets the -# pseudo-architecture "all" to indicate that it has no binaries (ie. is arch -# independent). -unset(CPACK_DEBIAN_PACKAGE_ARCHITECTURE) -set(CPACK_DEBIAN_HALIDE_DOCUMENTATION_PACKAGE_ARCHITECTURE all) - -# Package dependencies. -# TODO: figure out how to get LLVM major version piped in here. -set(CPACK_DEBIAN_HALIDE_RUNTIME_PACKAGE_DEPENDS "llvm-12 (>= 12.0.0)") -set(CPACK_DEBIAN_HALIDE_DEVELOPMENT_PACKAGE_DEPENDS "llvm-12-dev (>= 12.0.0), liblld-12-dev (>= 12.0.0)") -set(CPACK_DEBIAN_HALIDE_DOCUMENTATION_PACKAGE_DEPENDS "") - -# Sets up package dependencies based on CPack component dependencies -set(CPACK_DEBIAN_ENABLE_COMPONENT_DEPENDS ON) - -# Uses CPACK_PACKAGE_CONTACT as default -unset(CPACK_DEBIAN_PACKAGE_MAINTAINER) - -# These inherit their values from cpack cpack_add_component -unset(CPACK_DEBIAN_HALIDE_RUNTIME_DESCRIPTION) -unset(CPACK_DEBIAN_HALIDE_DEVELOPMENT_DESCRIPTION) -unset(CPACK_DEBIAN_HALIDE_DOCUMENTATION_DESCRIPTION) - -# The Debian repository package section. -# See: https://packages.debian.org/unstable/ -# libs = Libraries to make other programs work. They provide special features to developers. -# libdevel = Libraries necessary for developers to write programs that use them. -# doc = FAQs, HOWTOs and other documents trying to explain everything related to -# Debian, and software needed to browse documentation (man, info, etc). -set(CPACK_DEBIAN_HALIDE_RUNTIME_PACKAGE_SECTION libs) -set(CPACK_DEBIAN_HALIDE_DEVELOPMENT_PACKAGE_SECTION libdevel) -set(CPACK_DEBIAN_HALIDE_DOCUMENTATION_PACKAGE_SECTION doc) - -# Deprecated: do not use -unset(CPACK_DEBIAN_ARCHIVE_TYPE) - -# Could also choose from lzma, xz, or bzip2 if one gave a better ratio. -set(CPACK_DEBIAN_COMPRESSION_TYPE "gzip") - -# Optional just means that it is optional for the safe running of -# a Debian system to have our package installed. The other categories -# do not apply to us: required (won't boot without), important (core -# system utils), and standard (basic niceties for a character-mode -# system). -set(CPACK_DEBIAN_PACKAGE_PRIORITY "optional") - -# Uses CMAKE_PROJECT_HOMEPAGE_URL as default. -unset(CPACK_DEBIAN_PACKAGE_HOMEPAGE) - -# Call dpkg-shlibdeps to get dependencies on system libraries. -set(CPACK_DEBIAN_PACKAGE_SHLIBDEPS ON) -unset(CPACK_DEBIAN_PACKAGE_SHLIBDEPS_PRIVATE_DIRS) # CMake 3.20+ only - -# Disable debug messaging -unset(CPACK_DEBIAN_PACKAGE_DEBUG) - -# Special variables for package constraints. We don't have any yet. -unset(CPACK_DEBIAN_PACKAGE_PREDEPENDS) -unset(CPACK_DEBIAN_PACKAGE_ENHANCES) -unset(CPACK_DEBIAN_PACKAGE_BREAKS) -unset(CPACK_DEBIAN_PACKAGE_CONFLICTS) -unset(CPACK_DEBIAN_PACKAGE_PROVIDES) -unset(CPACK_DEBIAN_PACKAGE_REPLACES) -unset(CPACK_DEBIAN_PACKAGE_RECOMMENDS) -unset(CPACK_DEBIAN_PACKAGE_SUGGESTS) - -# Generate debian/shlibs control file; require exact versions. -set(CPACK_DEBIAN_PACKAGE_GENERATE_SHLIBS YES) -set(CPACK_DEBIAN_PACKAGE_GENERATE_SHLIBS_POLICY "=") - -# Add custom scripts to package. Used to ensure ldconfig runs. -unset(CPACK_DEBIAN_PACKAGE_CONTROL_EXTRA) -set(CPACK_DEBIAN_HALIDE_RUNTIME_PACKAGE_CONTROL_EXTRA - "${CMAKE_CURRENT_LIST_DIR}/triggers") -set(CPACK_DEBIAN_PACKAGE_CONTROL_STRICT_PERMISSION TRUE) - -# Name the source package for this one. TODO? -unset(CPACK_DEBIAN_PACKAGE_SOURCE) - -# Name the package containing debug symbols for this one. TODO? -unset(CPACK_DEBIAN_DEBUGINFO_PACKAGE) diff --git a/packaging/ubuntu/copyright b/packaging/ubuntu/copyright deleted file mode 100644 index f5529a29b208..000000000000 --- a/packaging/ubuntu/copyright +++ /dev/null @@ -1,26 +0,0 @@ -Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ -Upstream-Name: @CPACK_PACKAGE_NAME@ -Upstream-Contact: @CPACK_PACKAGE_CONTACT@ -Source: @CPACK_PACKAGE_HOMEPAGE_URL@ - -Files: * -Copyright: @copyright_line@ -License: MIT - Permission is hereby granted, free of charge, to any person obtaining - a copy of this software and associated documentation files (the - "Software"), to deal in the Software without restriction, including - without limitation the rights to use, copy, modify, merge, publish, - distribute, sublicense, and/or sell copies of the Software, and to - permit persons to whom the Software is furnished to do so, subject - to the following conditions: - . - The above copyright notice and this permission notice shall be included - in all copies or substantial portions of the Software. - . - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES - OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/packaging/ubuntu/extra-strip.sh b/packaging/ubuntu/extra-strip.sh deleted file mode 100755 index 993dbfcf70cc..000000000000 --- a/packaging/ubuntu/extra-strip.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash - -# See https://github.com/Debian/debhelper/blob/5d1bb29841043d8e47ebbdd043e6cd086cad508e/dh_strip#L362-L384 -# for what dh_strip removes. - -strip --remove-section=.comment --remove-section=.note "$@" diff --git a/packaging/ubuntu/package.sh b/packaging/ubuntu/package.sh deleted file mode 100755 index c831d3cfb710..000000000000 --- a/packaging/ubuntu/package.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash -set -e -o pipefail - -halide_source=$(realpath "$1") -halide_build_root=$(realpath "$2") - -[ -z "$halide_source" ] && echo "Usage: $0 " && exit -[ -z "$halide_build_root" ] && echo "Usage: $0 " && exit -[ -z "$LLVM_ROOT" ] && echo "Must set LLVM_ROOT to /usr/lib/llvm-VERSION" && exit - -function group() { - [[ -n "${GITHUB_ACTIONS}" && -n "${SEEN_GROUP}" ]] && echo "::endgroup::" - [[ -n "${GITHUB_ACTIONS}" ]] && echo "::group::$*" - export SEEN_GROUP=1 -} - -group "Configure shared Halide build" -cmake --preset=package-ubuntu-shared -S "$halide_source" -B "$halide_build_root/shared-Release" - -group "Configure static Halide build" -cmake --preset=package-ubuntu-static -S "$halide_source" -B "$halide_build_root/static-Release" - -group "Build shared Halide" -cmake --build "$halide_build_root/shared-Release" -- -v - -group "Build static Halide" -cmake --build "$halide_build_root/static-Release" -- -v - -group "Create Ubuntu packages" -cd "$halide_build_root" -rm -rf ./_CPack_Packages ./*.deb lintian.log -umask 0022 -export LD_LIBRARY_PATH="$halide_build_root/shared-Release/src" - -cpack -G DEB -C Release --config "$halide_source/packaging/ubuntu/config.cmake" - -# Lintian: https://lintian.debian.org/tags - -group "Run strict Lintian checks" -lintian --no-tag-display-limit -i ./*.deb - -group "Run extra Lintian checks" -lintian --no-tag-display-limit -L "=info" -i ./*.deb - -echo "Success!" diff --git a/packaging/ubuntu/pre_build.cmake b/packaging/ubuntu/pre_build.cmake deleted file mode 100644 index a74370d1fd70..000000000000 --- a/packaging/ubuntu/pre_build.cmake +++ /dev/null @@ -1,25 +0,0 @@ -cmake_minimum_required(VERSION 3.19) - -file(STRINGS "${CPACK_RESOURCE_FILE_LICENSE}" copyright_line LIMIT_COUNT 1) -string(TIMESTAMP timestamp "%a, %d %b %Y %H:%M:%S" UTC) - -find_program(GZIP gzip) -if (NOT GZIP) - message(FATAL_ERROR "Could not find gzip") -endif () - -foreach (comp IN LISTS CPACK_COMPONENTS_ALL) - string(TOUPPER "CPACK_DEBIAN_${comp}_PACKAGE_NAME" package_name_var) - string(TOLOWER "${${package_name_var}}" package_name) - - # Write copyright information to the package. - configure_file("${CMAKE_CURRENT_LIST_DIR}/copyright" - "${CPACK_TEMPORARY_DIRECTORY}/${comp}/usr/share/doc/${package_name}/copyright" - @ONLY NO_SOURCE_PERMISSIONS) - - # Write changelog to the package. - set(changelog "${CPACK_TEMPORARY_DIRECTORY}/${comp}/usr/share/doc/${package_name}/changelog") - configure_file("${CMAKE_CURRENT_LIST_DIR}/changelog" "${changelog}" - @ONLY NO_SOURCE_PERMISSIONS) - execute_process(COMMAND "${GZIP}" -n9 "${changelog}" COMMAND_ERROR_IS_FATAL ANY) -endforeach () diff --git a/packaging/ubuntu/triggers b/packaging/ubuntu/triggers deleted file mode 100644 index dd8660367847..000000000000 --- a/packaging/ubuntu/triggers +++ /dev/null @@ -1 +0,0 @@ -activate-noawait ldconfig diff --git a/python_bindings/CMakeLists.txt b/python_bindings/CMakeLists.txt index 75fc03d6354d..bcf9e3efcf77 100644 --- a/python_bindings/CMakeLists.txt +++ b/python_bindings/CMakeLists.txt @@ -1,29 +1,117 @@ +cmake_minimum_required(VERSION 3.22...3.23) +project(Halide_Python) + +include(CMakeDependentOption) + ## -# Load Python dependencies, including external pybind11 +# Project options +## + +# Preferred defaults for built-in options +set(CMAKE_CXX_STANDARD 17 CACHE STRING "The minimum C++ standard to use") +option(CMAKE_CXX_STANDARD_REQUIRED "Prevent CMake C++ standard selection decay" ON) +option(CMAKE_CXX_EXTENSIONS "Enable C++ vendor extensions (e.g. GNU)" OFF) + +# Duplicated options from parent project +option(WITH_TESTS "Build tests" ON) +option(WITH_TUTORIALS "Build tutorials" ON) + +# Enable/disable testing +cmake_dependent_option( + WITH_TEST_PYTHON "Build Python tests" ON + WITH_TESTS OFF +) + +# Set the expected (downloaded) version of pybind11 +option(PYBIND11_USE_FETCHCONTENT "Enable to download pybind11 via FetchContent" ON) +set(PYBIND11_VER 2.6.2 CACHE STRING "The pybind11 version to use (or download)") + +## +# Dependencies ## find_package(Python3 REQUIRED COMPONENTS Interpreter Development) -set(PYBIND11_VER 2.6.2) -find_package(pybind11 ${PYBIND11_VER} QUIET) -if (NOT pybind11_FOUND) +if (PYBIND11_USE_FETCHCONTENT) include(FetchContent) - FetchContent_Declare(pybind11 - GIT_REPOSITORY https://github.com/pybind/pybind11.git - GIT_TAG v${PYBIND11_VER}) + FetchContent_Declare( + pybind11 + GIT_REPOSITORY https://github.com/pybind/pybind11.git + GIT_TAG v${PYBIND11_VER} + ) FetchContent_MakeAvailable(pybind11) +else () + find_package(pybind11 ${PYBIND11_VER} REQUIRED) +endif () + +find_package(Halide REQUIRED) +if (NOT Halide_ENABLE_RTTI OR NOT Halide_ENABLE_EXCEPTIONS) + message(FATAL_ERROR "Python bindings require RTTI and exceptions to be enabled.") endif () +## +# A helper for creating tests with correct PYTHONPATH and sanitizer preloading +## + +if (Halide_ASAN_ENABLED) + if (NOT DEFINED Halide_Python_ASAN_LIBRARY) + # TODO: this assumes clang-on-Linux, we could be smarter here and check + # CMAKE_CXX_COMPILER_ID to behave differently on GNU, AppleClang, or + # MSVC. + execute_process( + COMMAND ${CMAKE_CXX_COMPILER} "-print-file-name=libclang_rt.asan.so" + OUTPUT_VARIABLE Halide_Python_ASAN_LIBRARY + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + endif () + + set(Halide_Python_ASAN_LIBRARY "${Halide_Python_ASAN_LIBRARY}" + CACHE FILEPATH "Library to preload when running Python tests.") +endif () + +function(add_python_test) + cmake_parse_arguments(ARG "" "FILE;LABEL" "PYTHONPATH;ENVIRONMENT" ${ARGN}) + + list(PREPEND ARG_PYTHONPATH "$/..") + list(TRANSFORM ARG_PYTHONPATH PREPEND "PYTHONPATH=path_list_prepend:") + + list(PREPEND ARG_ENVIRONMENT "HL_TARGET=${Halide_TARGET}") + if (Halide_Python_ASAN_LIBRARY) + if (APPLE) + list(PREPEND ARG_ENVIRONMENT "DYLD_INSERT_LIBRARIES=${Halide_Python_ASAN_LIBRARY}") + else () + list(PREPEND ARG_ENVIRONMENT "LD_PRELOAD=${Halide_Python_ASAN_LIBRARY}") + endif () + endif () + + cmake_path(GET ARG_FILE STEM test_name) + set(test_name "${ARG_LABEL}_${test_name}") + + add_test( + NAME "${test_name}" + COMMAND Python3::Interpreter "$" + ) + set_tests_properties( + "${test_name}" + PROPERTIES + LABELS "python" + ENVIRONMENT "${ARG_ENVIRONMENT}" + ENVIRONMENT_MODIFICATION "${ARG_PYTHONPATH}" + ) +endfunction() + + ## # Add our sources to this sub-tree. ## add_subdirectory(src) -include(stub/CMakeLists.txt) +add_subdirectory(stub) + +if (WITH_TEST_PYTHON) + add_subdirectory(test) +endif () -option(WITH_TEST_PYTHON "Build Python tests" ON) -if (WITH_TESTS AND WITH_TEST_PYTHON) - add_subdirectory(apps) - add_subdirectory(correctness) +if (WITH_TUTORIALS) add_subdirectory(tutorial) endif () diff --git a/python_bindings/Makefile b/python_bindings/Makefile deleted file mode 100644 index bc96a636646d..000000000000 --- a/python_bindings/Makefile +++ /dev/null @@ -1,219 +0,0 @@ -UNAME = $(shell uname) -THIS_MAKEFILE = $(realpath $(filter %Makefile, $(MAKEFILE_LIST))) -ROOT_DIR = $(strip $(shell dirname $(THIS_MAKEFILE))) - -# These are set by Halide's Makefile when built via that path. -HALIDE_PATH ?= $(ROOT_DIR)/.. -HALIDE_DISTRIB_PATH ?= $(HALIDE_PATH)/distrib -BIN ?= $(ROOT_DIR)/bin -PYTHON ?= python3 -TEST_TMP ?= $(BIN)/tmp - -FPIC=-fPIC -ifeq ($(UNAME), Darwin) - SHARED_EXT=dylib -else - SHARED_EXT=so -endif - -ifeq ($(UNAME), Linux) -USE_EXPORT_DYNAMIC=-rdynamic -else -ifeq ($(UNAME), Darwin) -USE_EXPORT_DYNAMIC=-undefined dynamic_lookup -else -USE_EXPORT_DYNAMIC= -endif -endif - -LIBHALIDE ?= $(HALIDE_DISTRIB_PATH)/lib/libHalide.$(SHARED_EXT) - -SUFFIX = $(shell $(PYTHON)-config --extension-suffix) - -# Discover PyBind path from `python3 -m pybind11 --includes` -PYBIND11_CFLAGS = $(shell $(PYTHON) -m pybind11 --includes) - -OPTIMIZE ?= -O3 - -# defining DEBUG + undefining NDEBUG gives extra debug info in PyBind11 -# OPTIMIZE ?= -g -DDEBUG=1 -UNDEBUG - -# Compiling with -fvisibility=hidden saves ~80k on optimized x64 builds. -# It's critical to include -fno-omit-frame-pointer, otherwise introspection can -# break in amusing ways. -CCFLAGS=$(shell $(PYTHON)-config --cflags) $(PYBIND11_CFLAGS) -I $(HALIDE_DISTRIB_PATH)/include -I $(ROOT_DIR) -std=c++17 $(FPIC) -fvisibility=hidden -fvisibility-inlines-hidden -fno-omit-frame-pointer $(OPTIMIZE) $(CXXFLAGS) -# Filter out a pointless warning present in some Python installs -CCFLAGS := $(filter-out -Wstrict-prototypes,$(CCFLAGS)) - -# DON'T link libpython* - leave those symbols to lazily resolve at load time -# Cf. https://github.com/pybind/pybind11/blob/master/docs/compiling.rst#building-manually -LDFLAGS += -lz $(USE_EXPORT_DYNAMIC) -LDFLAGS += -Wl,-rpath,$(dir $(LIBHALIDE)) - -PY_SRCS=$(shell ls $(ROOT_DIR)/src/*.cpp) -PY_OBJS=$(PY_SRCS:$(ROOT_DIR)/src/%.cpp=$(BIN)/src/%.o) - -MODULE=$(BIN)/halide$(SUFFIX) - -$(MODULE): $(PY_OBJS) $(LIBHALIDE) - @echo Building $@... - @mkdir -p $(@D) - @$(CXX) $^ $(LDFLAGS) -shared -o $@ - -# We don't want any of this auto-deleted -.SECONDARY: - -$(BIN)/src/%.o: $(ROOT_DIR)/src/%.cpp - @echo Building $@... - @mkdir -p $(@D) - @$(CXX) $(CCFLAGS) -c $< -o $@ - - -$(BIN)/%_generator.o: $(ROOT_DIR)/correctness/%_generator.cpp $(HALIDE_DISTRIB_PATH)/include/Halide.h - @echo Building $@... - @mkdir -p $(@D) - @$(CXX) $(CCFLAGS) -c $< -o $@ - -$(BIN)/PyStubImpl.o: $(ROOT_DIR)/stub/PyStubImpl.cpp $(HALIDE_DISTRIB_PATH)/include/Halide.h - @echo Building $@... - @mkdir -p $(@D) - @$(CXX) $(CCFLAGS) -c $< -o $@ - -# Produce a Python extension for a C++ generator by compiling PyStub.cpp -# (with HALIDE_PYSTUB_GENERATOR_NAME defined to the Generator's build name), -# and linking with the generator's .o file, PyStubImpl.o, plus the same libHalide -# being used by halide.so. -# -# You can optionally also define HALIDE_PYSTUB_MODULE_NAME if you want the Python -# module name to be something other than the Generator build name. -$(BIN)/%_PyStub.o: $(ROOT_DIR)/stub/PyStub.cpp - @echo Building $@... - @mkdir -p $(@D) - @$(CXX) $(CCFLAGS) -DHALIDE_PYSTUB_GENERATOR_NAME=$* -c $< -o $@ - -$(BIN)/generators/%.so: $(BIN)/%_PyStub.o $(BIN)/PyStubImpl.o $(BIN)/%_generator.o $(LIBHALIDE) - @echo Building $@... - @mkdir -p $(@D) - @$(CXX) $^ $(LDFLAGS) -shared -o $@ - -# Compile the generators: -$(BIN)/%.gen: $(HALIDE_DISTRIB_PATH)/tools/GenGen.cpp $(BIN)/%_generator.o $(LIBHALIDE) - @echo Building $@... - @mkdir -p $(@D) - @$(CXX) $(CCFLAGS) $(LDFLAGS) $^ -o $@ - -# Special generator for generating a runtime: -$(BIN)/runtime.gen: $(HALIDE_DISTRIB_PATH)/tools/GenGen.cpp $(LIBHALIDE) - @echo Building $@... - @mkdir -p $(@D) - @$(CXX) $(CCFLAGS) $(LDFLAGS) $^ -o $@ - -# Generate a runtime: -$(BIN)/runtime.a: $(BIN)/runtime.gen - @echo Building $@... - @mkdir -p $(@D) - @$< -r runtime -o $(BIN) target=host - -# Which target features to use for which test targets. -target_features_addconstant=-no_runtime -target_features_bit=-no_runtime -target_features_user_context=-user_context-no_runtime - -# Make the generator generate a Python extension: -$(BIN)/%.py.cpp $(BIN)/%.a $(BIN)/%.h: $(BIN)/%.gen - @echo Building $@... - @LD_LIBRARY_PATH=$(HALIDE_DISTRIB_PATH)/bin $< \ - -e static_library,c_header,python_extension \ - -g $(notdir $(basename $<)) -o $(BIN) \ - target=host$(target_features_$(notdir $(basename $<))) - -# Compile the generated Python extension(s): -$(BIN)/%.py.o: $(BIN)/%.py.cpp - @echo Building $@... - @$(CXX) -c $(FPIC) $(CCFLAGS) $^ -o $@ - -# Fake up a linker script that will export *just* the PyInit entry -# point we want. (If we don't do this we can have interesting failures -# when loading multiple of these Python extensions in the same space.) -ifeq ($(UNAME), Darwin) -$(BIN)/ext/%.ldscript: - @echo Building $@... - @mkdir -p $(@D) - @echo _PyInit_$* > $@ - -PYEXT_LDSCRIPT_FLAG = -Wl,-exported_symbols_list %LDSCRIPT% -else -# Assume Desktop Linux -$(BIN)/ext/%.ldscript: - @echo Building $@... - @mkdir -p $(@D) - @echo "{" > $@ - @echo " global: PyInit_$*;" >> $@ - @echo " local: *;" >> $@ - @echo "};" >> $@ -PYEXT_LDSCRIPT_FLAG = -Wl,--version-script=%LDSCRIPT% -endif - -# The Python extension of the generator is already in $(BIN), and is named -# the same, so put the Python extension of the function into ext/. -$(BIN)/ext/%.so: $(BIN)/%.py.o $(BIN)/%.a $(BIN)/runtime.a $(BIN)/ext/%.ldscript - @echo Building $@... - @mkdir -p $(@D) - @$(CXX) $(LDFLAGS) $(filter-out $(BIN)/ext/$*.ldscript,$^) -shared $(subst %LDSCRIPT%,$(BIN)/ext/$*.ldscript,$(PYEXT_LDSCRIPT_FLAG)) -o $@ - -test_correctness_addconstant_test: $(BIN)/ext/addconstant.so -test_correctness_bit_test: $(BIN)/ext/bit.so -test_correctness_user_context_test: $(BIN)/ext/user_context.so -test_correctness_pystub: $(BIN)/generators/simplestub.so $(BIN)/generators/complexstub.so - -APPS = $(shell ls $(ROOT_DIR)/apps/*.py) -CORRECTNESS = $(shell ls $(ROOT_DIR)/correctness/*.py) -TUTORIAL = $(shell ls $(ROOT_DIR)/tutorial/*.py) - -.PHONY: test_apps -test_apps: $(APPS:$(ROOT_DIR)/apps/%.py=test_apps_%) - -test_apps_%: $(ROOT_DIR)/apps/%.py $(MODULE) - @echo Testing $*... - @mkdir -p $(TEST_TMP) - @# Send stdout (but not stderr) from these to /dev/null to reduce noise - @cd $(TEST_TMP); PYTHONPATH="$(BIN):$$PYTHONPATH" $(PYTHON) $< >/dev/null - -.PHONY: test_correctness -test_correctness: $(CORRECTNESS:$(ROOT_DIR)/correctness/%.py=test_correctness_%) - -test_correctness_%: $(ROOT_DIR)/correctness/%.py $(MODULE) - @echo Testing $*... - @mkdir -p $(TEST_TMP) - @cd $(TEST_TMP); PYTHONPATH="$(BIN)/ext:$(BIN)/generators:$(BIN):$$PYTHONPATH" $(PYTHON) $< - -.PHONY: test_tutorial -test_tutorial: $(TUTORIAL:$(ROOT_DIR)/tutorial/%.py=test_tutorial_%) - -test_tutorial_%: $(ROOT_DIR)/tutorial/%.py $(MODULE) - @echo Testing $*... - @mkdir -p $(TEST_TMP) - @# Send stdout (but not stderr) from these to /dev/null to reduce noise - @# We need "." in the PYTHONPATH for lesson_10_halide.so. - @cd $(TEST_TMP); PYTHONPATH=".:$(BIN):$$PYTHONPATH" $(PYTHON) $< >/dev/null - -test_tutorial_lesson_10_aot_compilation_run: $(TEST_TMP)/lesson_10_halide.so - -$(TEST_TMP)/lesson_10_halide.so: test_tutorial_lesson_10_aot_compilation_generate - @echo Building $@... - @$(CXX) $(CCFLAGS) $(LDFLAGS) $(FPIC) -shared \ - $(TEST_TMP)/lesson_10_halide.py.cpp \ - $(TEST_TMP)/lesson_10_halide.o \ - -I $(TEST_TMP) -o $@ - -.PHONY: clean -clean: - rm -rf $(BIN) - -.PHONY: test -test: test_correctness test_apps test_tutorial - -# TODO(srj): the python bindings need to be put into the distrib folders; -# this is a hopefully-temporary workaround (https://github.com/halide/Halide/issues/4368) -.PHONY: build_python_bindings -build_python_bindings: $(MODULE) diff --git a/python_bindings/apps/CMakeLists.txt b/python_bindings/apps/CMakeLists.txt deleted file mode 100644 index b94f49e05926..000000000000 --- a/python_bindings/apps/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -set(SCRIPTS - bilateral_grid.py - blur.py - erode.py - interpolate.py - local_laplacian.py) - -foreach (SCRIPT IN LISTS SCRIPTS) - get_filename_component(BASE ${SCRIPT} NAME_WE) - add_test(NAME python_apps_${BASE} - COMMAND Python3::Interpreter "$") - set_tests_properties(python_apps_${BASE} PROPERTIES - LABELS python - ENVIRONMENT "PYTHONPATH=$>;HL_TARGET=${Halide_TARGET}") -endforeach () diff --git a/python_bindings/correctness/CMakeLists.txt b/python_bindings/correctness/CMakeLists.txt deleted file mode 100644 index e3085cca887e..000000000000 --- a/python_bindings/correctness/CMakeLists.txt +++ /dev/null @@ -1,55 +0,0 @@ -set(GENERATORS - complexstub_generator.cpp - simplestub_generator.cpp - ) - -foreach (GEN IN LISTS GENERATORS) - string(REPLACE "_generator.cpp" "" TARGET "${GEN}") - add_generator_python(${TARGET} ${GEN}) -endforeach () - -# Handle addconstant, bit, user_context -add_subdirectory(ext) - -add_library(the_sort_function MODULE the_sort_function.c) -target_link_libraries(the_sort_function PRIVATE Halide::Runtime) - -set(TESTS - addconstant_test.py - atomics.py - autodiff.py - basics.py - bit_test.py - boundary_conditions.py - buffer.py - compile_to.py - division.py - extern.py - float_precision_test.py - iroperator.py - multipass_constraints.py - pystub.py - rdom.py - realize_warnings.py - target.py - tuple_select.py - type.py - user_context_test.py - var.py - ) - -# Use generator expressions to get the true output paths of these files -# CMAKE_CURRENT_BINARY_DIR is incorrect. -make_shell_path(PYTHONPATH - "$" - "$" - "$") - -foreach (TEST IN LISTS TESTS) - get_filename_component(TEST_NAME ${TEST} NAME_WE) - add_test(NAME python_correctness_${TEST_NAME} - COMMAND Python3::Interpreter "$") - set_tests_properties(python_correctness_${TEST_NAME} PROPERTIES - LABELS "python" - ENVIRONMENT "PYTHONPATH=${PYTHONPATH};HL_TARGET=${Halide_TARGET}") -endforeach () diff --git a/python_bindings/correctness/ext/CMakeLists.txt b/python_bindings/correctness/ext/CMakeLists.txt deleted file mode 100644 index 40d41a23c5b5..000000000000 --- a/python_bindings/correctness/ext/CMakeLists.txt +++ /dev/null @@ -1,37 +0,0 @@ -include(TargetExportScript) - -set(FEATURES_user_context user_context) - -foreach (GEN IN ITEMS addconstant bit user_context) - # Create the Halide generator executable - add_executable(${GEN}.gen ../${GEN}_generator.cpp) - target_link_libraries(${GEN}.gen PRIVATE Halide::Generator) - - # Call it to generate the Python extension cpp file - add_halide_library(ext_${GEN} - FROM ${GEN}.gen - GENERATOR ${GEN} - FUNCTION_NAME ${GEN} - PYTHON_EXTENSION ${GEN}_py_cpp - FEATURES ${FEATURES_${GEN}} - TARGETS cmake) - - # Create the module from the generated library and .py.cpp - Python3_add_library(py_${GEN} MODULE ${${GEN}_py_cpp}) - target_link_libraries(py_${GEN} PRIVATE ext_${GEN}) - set_target_properties(py_${GEN} PROPERTIES OUTPUT_NAME ${GEN}) # Python3_add_library adds target info to name. - - # Fake up a linker script that will export *just* the PyInit entry - # point we want. (If we don't do this we can have interesting failures - # when loading multiple of these Python extensions in the same space.) - # - # TODO: How to do this for Windows as well? - configure_file(ext.ldscript.apple.in "${CMAKE_CURRENT_BINARY_DIR}/${GEN}.ldscript.apple") - configure_file(ext.ldscript.linux.in "${CMAKE_CURRENT_BINARY_DIR}/${GEN}.ldscript") - target_export_script( - py_${GEN} - APPLE_LD "${CMAKE_CURRENT_BINARY_DIR}/${GEN}.ldscript.apple" - GNU_LD "${CMAKE_CURRENT_BINARY_DIR}/${GEN}.ldscript" - ) - -endforeach () diff --git a/python_bindings/correctness/ext/ext.ldscript.apple.in b/python_bindings/correctness/ext/ext.ldscript.apple.in deleted file mode 100644 index ea4e8830c7a2..000000000000 --- a/python_bindings/correctness/ext/ext.ldscript.apple.in +++ /dev/null @@ -1 +0,0 @@ -_PyInit_${GEN} diff --git a/python_bindings/correctness/ext/ext.ldscript.linux.in b/python_bindings/correctness/ext/ext.ldscript.linux.in deleted file mode 100644 index a75f043345af..000000000000 --- a/python_bindings/correctness/ext/ext.ldscript.linux.in +++ /dev/null @@ -1,4 +0,0 @@ -{ -global: PyInit_${GEN}; -local: *; -}; diff --git a/python_bindings/pyproject.toml b/python_bindings/pyproject.toml new file mode 100644 index 000000000000..e3a89e6d5ef2 --- /dev/null +++ b/python_bindings/pyproject.toml @@ -0,0 +1,10 @@ +[build-system] +requires = [ + "setuptools>=42", + "wheel", + "scikit-build", + "pybind11==2.6.2", + "cmake>=3.22", + "ninja; platform_system!='Windows'" +] +build-backend = "setuptools.build_meta" diff --git a/python_bindings/readme.md b/python_bindings/readme.md index e285a1716cab..3192f12834a2 100644 --- a/python_bindings/readme.md +++ b/python_bindings/readme.md @@ -13,12 +13,12 @@ with some differences where the C++ idiom is either inappropriate or impossible: as `[]` is not syntactically acceptable in Python. - Some classes in the Halide API aren't provided because they are 'wrapped' with standard Python idioms: - - `Halide::Tuple` doesn't exist in the Python bindings; an ordinary Python - tuple of `Halide::Expr` is used instead. - - `Halide::Realization` doesn't exist in the Python bindings; an ordinary - Python tuple of `Halide::Buffer` is used instead. - - `Halide::Error` and friends don't exist; standard Python error handling is - used instead. + - `Halide::Tuple` doesn't exist in the Python bindings; an ordinary Python + tuple of `Halide::Expr` is used instead. + - `Halide::Realization` doesn't exist in the Python bindings; an ordinary + Python tuple of `Halide::Buffer` is used instead. + - `Halide::Error` and friends don't exist; standard Python error handling is + used instead. - static and instance method overloads with the same name in the same class aren't allowed, so some convenience methods are missing from `Halide::Var` - Templated types (notably `Halide::Buffer<>` and `Halide::Param<>`) aren't @@ -35,7 +35,9 @@ with some differences where the C++ idiom is either inappropriate or impossible: entirely for now. - `Func::in` becomes `Func.in_` because `in` is a Python keyword. - `Func::async` becomes `Func.async_` because `async` is a Python keyword. -- The `not` keyword cannot be used to negate boolean Halide expressions. Instead, the `logical_not` function can be used and is equivalent to using `operator!` in C++. +- The `not` keyword cannot be used to negate boolean Halide expressions. + Instead, the `logical_not` function can be used and is equivalent to + using `operator!` in C++. ## Enhancements to the C++ API @@ -47,8 +49,7 @@ with some differences where the C++ idiom is either inappropriate or impossible: ## Prerequisites The bindings (and demonstration applications) should work well for Python 3.4 -(or higher), on Linux and OSX platforms. Windows support is experimental, and -available through the CMake build. +(or higher), on Linux and OSX platforms. Windows support is experimental. #### Python requirements: @@ -57,18 +58,15 @@ The best way to get set up is to use a virtual environment: ```console $ python3 -m venv venv $ . venv/bin/activate -$ pip install -r requirements.txt +$ python3 -m pip install -U setuptools wheel +$ python3 -m pip install -r requirements.txt ``` -#### C++ requirements: - -- Halide compiled to a distribution (e.g. `make distrib` or similar), with the - `HALIDE_DISTRIB_PATH` env var pointing to it -- If using CMake, simply set `-DWITH_PYTHON_BINDINGS=ON` from the main build. - ## Compilation instructions -Build using: `make` +Build as part of the CMake build with `-DWITH_PYTHON_BINDINGS=ON`. Note that +this requires both Halide and LLVM to be built with RTTI and exceptions +**enabled**, which is not the default for LLVM. ## Documentation and Examples @@ -78,11 +76,14 @@ The Python API reflects directly the Check out the code for the example applications in the `apps/` and `tutorial/` subdirectory. -You can run them as a batch via `make test_apps` or `make test_tutorial`. +The tests run as part of the standard CTest infrastructure and are labeled with +the `python` label. You can run the Python tests specifically by running: + +``` +$ ctest -L python +``` -To run these examples, make sure the `PYTHONPATH` environment variable points to -your build directory (e.g. -`export PYTHONPATH=halide_source/python_bindings/bin:$PYTHONPATH`). +From the Halide build directory. ## License diff --git a/python_bindings/setup.py b/python_bindings/setup.py new file mode 100644 index 000000000000..1a711ebc015e --- /dev/null +++ b/python_bindings/setup.py @@ -0,0 +1,53 @@ +from skbuild import setup, cmaker, utils +from setuptools import find_packages +from pathlib import Path +import pybind11 +from tempfile import TemporaryDirectory as mkdtemp_ctx +import textwrap + + +def get_version(): + """ + Builds a dummy project that prints the found Halide version. The "version" + of these Halide bindings is whatever version of Halide they're building + against. + """ + + cmakelists_txt = textwrap.dedent( + """ + cmake_minimum_required(VERSION 3.22) + project(dummy) + find_package(Halide REQUIRED) + file(WRITE halide_version.txt "${Halide_VERSION}") + """ + ) + + with mkdtemp_ctx() as srcdir, mkdtemp_ctx() as dstdir: + src, dst = Path(srcdir), Path(dstdir) + (src / "CMakeLists.txt").write_text(cmakelists_txt) + with utils.push_dir(dst): + cmkr = cmaker.CMaker() + cmkr.configure(cmake_source_dir=src, clargs=("--no-warn-unused-cli",)) + version = (src / "halide_version.txt").read_text().strip() + return version + + +setup( + name="halide", + version=get_version(), + author="The Halide team", + author_email="", + description="", + long_description=Path("readme.md").read_text(), + python_requires=">=3.6", + packages=find_packages(where="src"), + package_dir={"": "src"}, + cmake_args=[ + f"-Dpybind11_ROOT={pybind11.get_cmake_dir()}", + "-DCMAKE_REQUIRE_FIND_PACKAGE_pybind11=YES", + "-DHalide_INSTALL_PYTHONDIR=src", + "-DCMAKE_INSTALL_RPATH=$ORIGIN", + "-DHalide_Python_INSTALL_IMPORTED_DEPS=ON", + "--no-warn-unused-cli", + ], +) diff --git a/python_bindings/src/CMakeLists.txt b/python_bindings/src/CMakeLists.txt index f8a57b61e74f..aa44c7e36f7f 100644 --- a/python_bindings/src/CMakeLists.txt +++ b/python_bindings/src/CMakeLists.txt @@ -1,50 +1 @@ -set(SOURCES - PyArgument.cpp - PyBoundaryConditions.cpp - PyBuffer.cpp - PyConciseCasts.cpp - PyDerivative.cpp - PyEnums.cpp - PyError.cpp - PyExpr.cpp - PyExternFuncArgument.cpp - PyFunc.cpp - PyFuncRef.cpp - PyHalide.cpp - PyImageParam.cpp - PyInlineReductions.cpp - PyIROperator.cpp - PyLambda.cpp - PyLoopLevel.cpp - PyMachineParams.cpp - PyModule.cpp - PyParam.cpp - PyPipeline.cpp - PyRDom.cpp - PyStage.cpp - PyTarget.cpp - PyTuple.cpp - PyType.cpp - PyVar.cpp - PyVarOrRVar.cpp - PyEvictionKey.cpp - ) - -pybind11_add_module(Halide_Python MODULE ${SOURCES}) -add_library(Halide::Python ALIAS Halide_Python) -set_target_properties(Halide_Python - PROPERTIES - LIBRARY_OUTPUT_NAME halide - EXPORT_NAME Python) -target_link_libraries(Halide_Python PRIVATE Halide::Halide) - -if (WIN32 AND BUILD_SHARED_LIBS) - # There's precious little information about why Python only sometimes prevents DLLs from loading from the PATH on Windows. - # This workaround places a copy of Halide.dll next to our Python module. - # Ref: https://stackoverflow.com/questions/59860465/pybind11-importerror-dll-not-found-when-trying-to-import-pyd-in-python-int - # Ref: https://bugs.python.org/issue36085 - # Ref: https://docs.python.org/3/whatsnew/3.8.html#bpo-36085-whatsnew - add_custom_command(TARGET Halide_Python POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $ - VERBATIM) -endif () +add_subdirectory(halide) diff --git a/python_bindings/src/PyBuffer.h b/python_bindings/src/PyBuffer.h deleted file mode 100644 index b4b2502ea748..000000000000 --- a/python_bindings/src/PyBuffer.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef HALIDE_PYTHON_BINDINGS_PYBUFFER_H -#define HALIDE_PYTHON_BINDINGS_PYBUFFER_H - -#include "PyHalide.h" - -namespace Halide { -namespace PythonBindings { - -void define_buffer(py::module &m); - -} // namespace PythonBindings -} // namespace Halide - -#endif // HALIDE_PYTHON_BINDINGS_PYBUFFER_H diff --git a/python_bindings/src/halide/CMakeLists.txt b/python_bindings/src/halide/CMakeLists.txt new file mode 100644 index 000000000000..6fc85e63fca3 --- /dev/null +++ b/python_bindings/src/halide/CMakeLists.txt @@ -0,0 +1,165 @@ +set(native_sources + PyArgument.cpp + PyBoundaryConditions.cpp + PyBuffer.cpp + PyCallable.cpp + PyConciseCasts.cpp + PyDerivative.cpp + PyEnums.cpp + PyError.cpp + PyExpr.cpp + PyExternFuncArgument.cpp + PyFunc.cpp + PyFuncRef.cpp + PyHalide.cpp + PyImageParam.cpp + PyInlineReductions.cpp + PyIROperator.cpp + PyLambda.cpp + PyLoopLevel.cpp + PyMachineParams.cpp + PyModule.cpp + PyParam.cpp + PyPipeline.cpp + PyRDom.cpp + PyStage.cpp + PyTarget.cpp + PyTuple.cpp + PyType.cpp + PyVar.cpp + PyVarOrRVar.cpp + PyEvictionKey.cpp + ) +list(TRANSFORM native_sources PREPEND "halide_/") + +set(python_sources + __init__.py + ) + +# It is technically still possible for a user to override the LIBRARY_OUTPUT_DIRECTORY by setting +# CMAKE_LIBRARY_OUTPUT_DIRECTORY_, but they do so at their own peril. If a user needs to +# do this, they should use the CMAKE_PROJECT_Halide_Python_INCLUDE_BEFORE variable to override it +# just for this project, rather than globally, and they should ensure that the last path component +# is `halide`. Otherwise, the tests will break. +pybind11_add_module(Halide_Python MODULE ${native_sources}) +add_library(Halide::Python ALIAS Halide_Python) +set_target_properties( + Halide_Python + PROPERTIES + LIBRARY_OUTPUT_NAME halide_ + LIBRARY_OUTPUT_DIRECTORY "$/halide" + EXPORT_NAME Python +) +target_link_libraries(Halide_Python PRIVATE Halide::Halide) + +# TODO: There's precious little information about why Python only sometimes prevents DLLs from loading from the PATH +# on Windows. This workaround places a copy of Halide.dll (and any other dependencies) next to our Python module. +# Ref: https://stackoverflow.com/questions/59860465/pybind11-importerror-dll-not-found-when-trying-to-import-pyd-in-python-int +# Ref: https://bugs.python.org/issue36085 +# Ref: https://docs.python.org/3/whatsnew/3.8.html#bpo-36085-whatsnew +# TODO: copying a dummy file here works around a CMake limitation. The issue is that if $ is +# empty, then copy_if_different errors out, thinking it doesn't have enough arguments. +# Ref: https://gitlab.kitware.com/cmake/cmake/-/issues/23543 +set(dummy_file "${CMAKE_CURRENT_BINARY_DIR}/.dummy_file") +file(TOUCH "${dummy_file}") +add_custom_command( + TARGET Halide_Python POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${dummy_file}" $ $ + COMMAND_EXPAND_LISTS + VERBATIM +) + +# Copy our Python source files over so that we have a valid package in the binary directory. +# TODO: When upgrading to CMake 3.23 or beyond, investigate the FILE_SET feature. +set(build_tree_pys "") +foreach (pysrc IN LISTS python_sources) + # TODO: CMake 3.22 still doesn't allow target-dependent genex in OUTPUT, but we can hack around this using a stamp + # file. Fix this hack up if and when they ever improve this feature. + set(stamp_file "${CMAKE_CURRENT_BINARY_DIR}/.${pysrc}.stamp") + add_custom_command( + OUTPUT "${stamp_file}" + COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_CURRENT_SOURCE_DIR}/${pysrc}" "$/${pysrc}" + COMMAND ${CMAKE_COMMAND} -E touch "${stamp_file}" + DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/${pysrc}" + VERBATIM + ) + list(APPEND build_tree_pys "${stamp_file}") +endforeach () +add_custom_target(Halide_Python_sources ALL DEPENDS ${build_tree_pys}) +add_dependencies(Halide_Python Halide_Python_sources) + +## +# Packaging +## + +include(CMakeDependentOption) +include(GNUInstallDirs) + +set(Halide_INSTALL_PYTHONDIR "${CMAKE_INSTALL_LIBDIR}/python3/site-packages" + CACHE STRING "Path to the Python site-packages folder") + +install(DIRECTORY "$/" + DESTINATION "${Halide_INSTALL_PYTHONDIR}/halide" + COMPONENT Halide_Python + FILES_MATCHING + PATTERN "*.py" + PATTERN "*/halide_" EXCLUDE + PATTERN "*/CMakeFiles" EXCLUDE + PATTERN "*/__pycache__" EXCLUDE) + +install(TARGETS Halide_Python + LIBRARY DESTINATION "${Halide_INSTALL_PYTHONDIR}/halide" + COMPONENT Halide_Python) + +get_property(halide_is_imported TARGET Halide::Halide PROPERTY IMPORTED) +get_property(halide_type TARGET Halide::Halide PROPERTY TYPE) +cmake_dependent_option( + Halide_Python_INSTALL_IMPORTED_DEPS "" OFF + "halide_is_imported;halide_type STREQUAL \"SHARED_LIBRARY\"" OFF +) + +if (Halide_Python_INSTALL_IMPORTED_DEPS) + # The following might be a bit confusing, but installing both libHalide + # and its SONAME symbolic link causes the following bad behavior: + # 1. CMake does the right thing and installs libHalide.so.X.Y.Z + # (TARGET_FILE) as a real file and libHalide.so.X + # (TARGET_SONAME_FILE_NAME) as a symbolic link to the former. + # 2. Setuptools dutifully packs both of these into a Python wheel, which + # is a structured zip file. Zip files do not support symbolic links. + # Thus, two independent copies of libHalide are inserted, bloating the + # package. + # The Python module (on Unix systems) links to the SONAME file, and + # installing the symbolic link directly results in a broken link. Hence, + # the renaming dance here. + + if (NOT MSVC) + set(rename_arg RENAME "$") + else () + # DLL systems do not have sonames. + set(rename_arg "") + endif () + + # TODO: when we upgrade to CMake 3.22, replace with RUNTIME_DEPENDENCY_SET? + install(FILES "$" + DESTINATION "${Halide_INSTALL_PYTHONDIR}/halide" + COMPONENT Halide_Python + ${rename_arg}) +endif () + +if ( + NOT CMAKE_INSTALL_RPATH # Honor user overrides + AND NOT halide_is_imported # Imported Halide means user is responsible for RPATH + AND halide_type STREQUAL "SHARED_LIBRARY" # No need to set RPATH if statically linked +) + if (APPLE) + set(rbase @loader_path) + else () + set(rbase $ORIGIN) + endif () + + file(RELATIVE_PATH lib_dir + "${CMAKE_CURRENT_BINARY_DIR}/${Halide_INSTALL_PYTHONDIR}/halide" + "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_INSTALL_LIBDIR}") + + set_target_properties(Halide_Python PROPERTIES INSTALL_RPATH "${rbase}/${lib_dir}") +endif () diff --git a/python_bindings/src/halide/__init__.py b/python_bindings/src/halide/__init__.py new file mode 100644 index 000000000000..20c9f793e327 --- /dev/null +++ b/python_bindings/src/halide/__init__.py @@ -0,0 +1,2 @@ +from .halide_ import * +from .halide_ import _, _1, _2, _3, _4, _5, _6, _7, _8, _9 diff --git a/python_bindings/src/PyArgument.cpp b/python_bindings/src/halide/halide_/PyArgument.cpp similarity index 100% rename from python_bindings/src/PyArgument.cpp rename to python_bindings/src/halide/halide_/PyArgument.cpp diff --git a/python_bindings/src/PyArgument.h b/python_bindings/src/halide/halide_/PyArgument.h similarity index 100% rename from python_bindings/src/PyArgument.h rename to python_bindings/src/halide/halide_/PyArgument.h diff --git a/python_bindings/src/PyBinaryOperators.h b/python_bindings/src/halide/halide_/PyBinaryOperators.h similarity index 98% rename from python_bindings/src/PyBinaryOperators.h rename to python_bindings/src/halide/halide_/PyBinaryOperators.h index ba3536f0b735..137a30dd4d5f 100644 --- a/python_bindings/src/PyBinaryOperators.h +++ b/python_bindings/src/halide/halide_/PyBinaryOperators.h @@ -61,6 +61,7 @@ HANDLE_SCALAR_TYPE(int8_t) HANDLE_SCALAR_TYPE(int16_t) HANDLE_SCALAR_TYPE(int32_t) HANDLE_SCALAR_TYPE(int64_t) +// HANDLE_SCALAR_TYPE(bfloat16_t) TODO: https://github.com/halide/Halide/issues/6849 HANDLE_SCALAR_TYPE(float16_t) HANDLE_SCALAR_TYPE(float) HANDLE_SCALAR_TYPE(double) diff --git a/python_bindings/src/PyBoundaryConditions.cpp b/python_bindings/src/halide/halide_/PyBoundaryConditions.cpp similarity index 100% rename from python_bindings/src/PyBoundaryConditions.cpp rename to python_bindings/src/halide/halide_/PyBoundaryConditions.cpp diff --git a/python_bindings/src/PyBoundaryConditions.h b/python_bindings/src/halide/halide_/PyBoundaryConditions.h similarity index 100% rename from python_bindings/src/PyBoundaryConditions.h rename to python_bindings/src/halide/halide_/PyBoundaryConditions.h diff --git a/python_bindings/src/PyBuffer.cpp b/python_bindings/src/halide/halide_/PyBuffer.cpp similarity index 89% rename from python_bindings/src/PyBuffer.cpp rename to python_bindings/src/halide/halide_/PyBuffer.cpp index 65c2cd18b570..7fb2ff84142d 100644 --- a/python_bindings/src/PyBuffer.cpp +++ b/python_bindings/src/halide/halide_/PyBuffer.cpp @@ -57,6 +57,12 @@ inline float16_t value_cast(const py::object &value) { return float16_t(value.cast()); } +// TODO: https://github.com/halide/Halide/issues/6849 +// template<> +// inline bfloat16_t value_cast(const py::object &value) { +// return bfloat16_t(value.cast()); +// } + template inline std::string format_descriptor() { return py::format_descriptor::format(); @@ -67,6 +73,12 @@ inline std::string format_descriptor() { return "e"; } +// TODO: https://github.com/halide/Halide/issues/6849 +// template<> +// inline std::string format_descriptor() { +// return there-is-no-python-buffer-format-descriptor-for-bfloat16; +// } + void call_fill(Buffer<> &b, const py::object &value) { #define HANDLE_BUFFER_TYPE(TYPE) \ @@ -84,6 +96,8 @@ void call_fill(Buffer<> &b, const py::object &value) { HANDLE_BUFFER_TYPE(int16_t) HANDLE_BUFFER_TYPE(int32_t) HANDLE_BUFFER_TYPE(int64_t) + // TODO: https://github.com/halide/Halide/issues/6849 + // HANDLE_BUFFER_TYPE(bfloat16_t) HANDLE_BUFFER_TYPE(float16_t) HANDLE_BUFFER_TYPE(float) HANDLE_BUFFER_TYPE(double) @@ -109,6 +123,8 @@ bool call_all_equal(Buffer<> &b, const py::object &value) { HANDLE_BUFFER_TYPE(int16_t) HANDLE_BUFFER_TYPE(int32_t) HANDLE_BUFFER_TYPE(int64_t) + // TODO: https://github.com/halide/Halide/issues/6849 + // HANDLE_BUFFER_TYPE(bfloat16_t) HANDLE_BUFFER_TYPE(float16_t) HANDLE_BUFFER_TYPE(float) HANDLE_BUFFER_TYPE(double) @@ -133,6 +149,8 @@ std::string type_to_format_descriptor(const Type &type) { HANDLE_BUFFER_TYPE(int16_t) HANDLE_BUFFER_TYPE(int32_t) HANDLE_BUFFER_TYPE(int64_t) + // TODO: https://github.com/halide/Halide/issues/6849 + // HANDLE_BUFFER_TYPE(bfloat16_t) HANDLE_BUFFER_TYPE(float16_t) HANDLE_BUFFER_TYPE(float) HANDLE_BUFFER_TYPE(double) @@ -143,6 +161,8 @@ std::string type_to_format_descriptor(const Type &type) { return std::string(); } +} // namespace + Type format_descriptor_to_type(const std::string &fd) { #define HANDLE_BUFFER_TYPE(TYPE) \ @@ -157,6 +177,8 @@ Type format_descriptor_to_type(const std::string &fd) { HANDLE_BUFFER_TYPE(int16_t) HANDLE_BUFFER_TYPE(int32_t) HANDLE_BUFFER_TYPE(int64_t) + // TODO: https://github.com/halide/Halide/issues/6849 + // HANDLE_BUFFER_TYPE(bfloat16_t) HANDLE_BUFFER_TYPE(float16_t) HANDLE_BUFFER_TYPE(float) HANDLE_BUFFER_TYPE(double) @@ -173,6 +195,8 @@ Type format_descriptor_to_type(const std::string &fd) { return Type(); } +namespace { + py::object buffer_getitem_operator(Buffer<> &buf, const std::vector &pos) { if ((size_t)pos.size() != (size_t)buf.dimensions()) { throw py::value_error("Incorrect number of dimensions."); @@ -192,6 +216,8 @@ py::object buffer_getitem_operator(Buffer<> &buf, const std::vector &pos) { HANDLE_BUFFER_TYPE(int16_t) HANDLE_BUFFER_TYPE(int32_t) HANDLE_BUFFER_TYPE(int64_t) + // TODO: https://github.com/halide/Halide/issues/6849 + // HANDLE_BUFFER_TYPE(bfloat16_t) HANDLE_BUFFER_TYPE(float16_t) HANDLE_BUFFER_TYPE(float) HANDLE_BUFFER_TYPE(double) @@ -220,6 +246,8 @@ py::object buffer_setitem_operator(Buffer<> &buf, const std::vector &pos, c HANDLE_BUFFER_TYPE(int16_t) HANDLE_BUFFER_TYPE(int32_t) HANDLE_BUFFER_TYPE(int64_t) + // TODO: https://github.com/halide/Halide/issues/6849 + // HANDLE_BUFFER_TYPE(bfloat16_t) HANDLE_BUFFER_TYPE(float16_t) HANDLE_BUFFER_TYPE(float) HANDLE_BUFFER_TYPE(double) @@ -236,26 +264,8 @@ py::object buffer_setitem_operator(Buffer<> &buf, const std::vector &pos, c class PyBuffer : public Buffer<> { py::buffer_info info; - static std::vector make_dim_vec(const py::buffer_info &info) { - const Type t = format_descriptor_to_type(info.format); - std::vector dims; - dims.reserve(info.ndim); - for (int i = 0; i < info.ndim; i++) { - if (INT_MAX < info.shape[i] || INT_MAX < (info.strides[i] / t.bytes())) { - throw py::value_error("Out of range arguments to make_dim_vec."); - } - dims.emplace_back(0, (int32_t)info.shape[i], (int32_t)(info.strides[i] / t.bytes())); - } - return dims; - } - PyBuffer(py::buffer_info &&info, const std::string &name) - : Buffer<>( - format_descriptor_to_type(info.format), - info.ptr, - (int)info.ndim, - make_dim_vec(info).data(), - name), + : Buffer<>(pybufferinfo_to_halidebuffer(info), name), info(std::move(info)) { } @@ -313,10 +323,10 @@ void define_buffer(py::module &m) { const int d = b.dimensions(); const int bytes = b.type().bytes(); - std::vector shape, strides; + std::vector shape, strides; for (int i = 0; i < d; i++) { - shape.push_back((ssize_t)b.raw_buffer()->dim[i].extent); - strides.push_back((ssize_t)(b.raw_buffer()->dim[i].stride * bytes)); + shape.push_back((Py_ssize_t)b.raw_buffer()->dim[i].extent); + strides.push_back((Py_ssize_t)(b.raw_buffer()->dim[i].stride * bytes)); } return py::buffer_info( @@ -363,23 +373,23 @@ void define_buffer(py::module &m) { .def("set_name", &Buffer<>::set_name) .def("name", &Buffer<>::name) - .def("same_as", (bool (Buffer<>::*)(const Buffer<> &other) const) & Buffer<>::same_as, py::arg("other")) + .def("same_as", (bool(Buffer<>::*)(const Buffer<> &other) const) & Buffer<>::same_as, py::arg("other")) .def("defined", &Buffer<>::defined) .def("type", &Buffer<>::type) - .def("channels", (int (Buffer<>::*)() const) & Buffer<>::channels) - .def("dimensions", (int (Buffer<>::*)() const) & Buffer<>::dimensions) - .def("width", (int (Buffer<>::*)() const) & Buffer<>::width) - .def("height", (int (Buffer<>::*)() const) & Buffer<>::height) - .def("top", (int (Buffer<>::*)() const) & Buffer<>::top) - .def("bottom", (int (Buffer<>::*)() const) & Buffer<>::bottom) - .def("left", (int (Buffer<>::*)() const) & Buffer<>::left) - .def("right", (int (Buffer<>::*)() const) & Buffer<>::right) + .def("channels", (int(Buffer<>::*)() const) & Buffer<>::channels) + .def("dimensions", (int(Buffer<>::*)() const) & Buffer<>::dimensions) + .def("width", (int(Buffer<>::*)() const) & Buffer<>::width) + .def("height", (int(Buffer<>::*)() const) & Buffer<>::height) + .def("top", (int(Buffer<>::*)() const) & Buffer<>::top) + .def("bottom", (int(Buffer<>::*)() const) & Buffer<>::bottom) + .def("left", (int(Buffer<>::*)() const) & Buffer<>::left) + .def("right", (int(Buffer<>::*)() const) & Buffer<>::right) .def("number_of_elements", (size_t(Buffer<>::*)() const) & Buffer<>::number_of_elements) .def("size_in_bytes", (size_t(Buffer<>::*)() const) & Buffer<>::size_in_bytes) - .def("has_device_allocation", (bool (Buffer<>::*)() const) & Buffer<>::has_device_allocation) - .def("host_dirty", (bool (Buffer<>::*)() const) & Buffer<>::host_dirty) - .def("device_dirty", (bool (Buffer<>::*)() const) & Buffer<>::device_dirty) + .def("has_device_allocation", (bool(Buffer<>::*)() const) & Buffer<>::has_device_allocation) + .def("host_dirty", (bool(Buffer<>::*)() const) & Buffer<>::host_dirty) + .def("device_dirty", (bool(Buffer<>::*)() const) & Buffer<>::device_dirty) .def( "set_host_dirty", [](Buffer<> &b, bool dirty) -> void { @@ -395,13 +405,13 @@ void define_buffer(py::module &m) { .def("copy", &Buffer<>::copy) .def("copy_from", &Buffer<>::copy_from::AnyDims>) - .def("add_dimension", (void (Buffer<>::*)()) & Buffer<>::add_dimension) + .def("add_dimension", (void(Buffer<>::*)()) & Buffer<>::add_dimension) .def("allocate", [](Buffer<> &b) -> void { b.allocate(nullptr, nullptr); }) - .def("deallocate", (void (Buffer<>::*)()) & Buffer<>::deallocate) - .def("device_deallocate", (void (Buffer<>::*)()) & Buffer<>::device_deallocate) + .def("deallocate", (void(Buffer<>::*)()) & Buffer<>::deallocate) + .def("device_deallocate", (void(Buffer<>::*)()) & Buffer<>::device_deallocate) .def( "crop", [](Buffer<> &b, int d, int min, int extent) -> void { diff --git a/python_bindings/src/halide/halide_/PyBuffer.h b/python_bindings/src/halide/halide_/PyBuffer.h new file mode 100644 index 000000000000..8b108c4e2abc --- /dev/null +++ b/python_bindings/src/halide/halide_/PyBuffer.h @@ -0,0 +1,39 @@ +#ifndef HALIDE_PYTHON_BINDINGS_PYBUFFER_H +#define HALIDE_PYTHON_BINDINGS_PYBUFFER_H + +#include "PyHalide.h" + +namespace Halide { +namespace PythonBindings { + +void define_buffer(py::module &m); + +Type format_descriptor_to_type(const std::string &fd); + +template +Halide::Runtime::Buffer pybufferinfo_to_halidebuffer(const py::buffer_info &info) { + const Type t = format_descriptor_to_type(info.format); + halide_dimension_t *dims = (halide_dimension_t *)alloca(info.ndim * sizeof(halide_dimension_t)); + _halide_user_assert(dims); + for (int i = 0; i < info.ndim; i++) { + if (INT_MAX < info.shape[i] || INT_MAX < (info.strides[i] / t.bytes())) { + throw py::value_error("Out of range dimensions in buffer conversion."); + } + dims[i] = {0, (int32_t)info.shape[i], (int32_t)(info.strides[i] / t.bytes())}; + } + return Halide::Runtime::Buffer(t, info.ptr, (int)info.ndim, dims); +} + +template +Halide::Runtime::Buffer pybuffer_to_halidebuffer(const py::buffer &pyb, bool writable) { + return pybufferinfo_to_halidebuffer(pyb.request(writable)); +} + +} // namespace PythonBindings +} // namespace Halide + +#endif // HALIDE_PYTHON_BINDINGS_PYBUFFER_H diff --git a/python_bindings/src/halide/halide_/PyCallable.cpp b/python_bindings/src/halide/halide_/PyCallable.cpp new file mode 100644 index 000000000000..52bdd9ae5559 --- /dev/null +++ b/python_bindings/src/halide/halide_/PyCallable.cpp @@ -0,0 +1,193 @@ +#include "PyCallable.h" + +#include "PyBuffer.h" + +#define TYPED_ALLOCA(TYPE, COUNT) ((TYPE *)alloca(sizeof(TYPE) * (COUNT))) + +namespace Halide { +namespace PythonBindings { + +namespace { + +// We avoid extra dynamic memory allocations for Buffers by preallocating enough +// space for 8 (rather than the default of 4) -- more is ok but slower, and > 8 +// seems pretty unlikely for real world code. +constexpr int MaxFastDimensions = 8; +using HalideBuffer = Halide::Runtime::Buffer; + +struct HBufArray { + const size_t count; + HalideBuffer *buffers; + + explicit HBufArray(size_t count, void *storage) + : count(count), buffers((HalideBuffer *)storage) { + _halide_user_assert(storage); + for (size_t i = 0; i < count; i++) { + // placement new to get the ctors run + new (&buffers[i]) HalideBuffer; + } + } + + ~HBufArray() { + for (size_t i = 0; i < count; i++) { + // Manually call the dtors + buffers[i].~HalideBuffer(); + } + } +}; + +template +T cast_to(const py::handle &h) { + // We want to ensure that the error thrown is one that will be translated + // to `hl.HalideError` in Python. + try { + return h.cast(); + } catch (const std::exception &e) { + throw Halide::Error(e.what()); + } +} + +} // namespace + +class PyCallable { +public: + // TODO: support kwargs here too. + static void call_impl(Callable &c, const py::args &args, const py::kwargs &kwargs) { + const size_t argc = c.arguments().size(); + _halide_user_assert(argc > 0); + const Argument *c_args = c.arguments().data(); + + // We want to keep call overhead as low as possible here, + // so use alloca (rather than e.g. std::vector) for short-term + // small allocations. + const void **argv = TYPED_ALLOCA(const void *, argc); + halide_scalar_value_t *scalar_storage = TYPED_ALLOCA(halide_scalar_value_t, argc); + HBufArray buffers(argc, TYPED_ALLOCA(HalideBuffer, argc)); + Callable::QuickCallCheckInfo *cci = TYPED_ALLOCA(Callable::QuickCallCheckInfo, argc); + + _halide_user_assert(argv && scalar_storage && buffers.buffers && cci) << "alloca failure"; + + // Clear argv to all zero so we can use it to validate that all fields are + // set properly when using kwargs -- a well-formed call will never have any + // of the fields left null, nor any set twice. (The other alloca stuff can + // keep garbage in unused parts.) + memset(argv, 0, sizeof(const void *) * argc); + + _halide_user_assert(args.size() <= argc - 1) + << "Expected at most " << (argc - 1) << " positional arguments, but saw " << args.size() << "."; + + // args + JITUserContext empty_jit_user_context; + scalar_storage[0].u.u64 = (uintptr_t)&empty_jit_user_context; + argv[0] = &scalar_storage[0]; + cci[0] = Callable::make_ucon_qcci(); + + const auto define_one_arg = [&argv, &scalar_storage, &buffers, &cci](const Argument &c_arg, py::handle value, size_t slot) { + if (c_arg.is_buffer()) { + // If the argument is already a Halide Buffer of some sort, + // skip pybuffer_to_halidebuffer entirely, since the latter requires + // a non-null host ptr, but we might want such a buffer for bounds inference, + // and we don't need the intermediate HalideBuffer wrapper anyway. + if (py::isinstance>(value)) { + auto b = cast_to>(value); + argv[slot] = b.raw_buffer(); + } else { + const bool writable = c_arg.is_output(); + buffers.buffers[slot] = pybuffer_to_halidebuffer(cast_to(value), writable); + argv[slot] = buffers.buffers[slot].raw_buffer(); + } + cci[slot] = Callable::make_buffer_qcci(); + } else { + argv[slot] = &scalar_storage[slot]; + + // clang-format off + + #define HALIDE_HANDLE_TYPE_DISPATCH(CODE, BITS, TYPE, FIELD) \ + case halide_type_t(CODE, BITS).as_u32(): \ + scalar_storage[slot].u.FIELD = cast_to(value); \ + cci[slot] = Callable::make_scalar_qcci(halide_type_t(CODE, BITS)); \ + break; + + switch (((halide_type_t)c_arg.type).element_of().as_u32()) { + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_float, 32, float, f32) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_float, 64, double, f64) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_int, 8, int8_t, i8) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_int, 16, int16_t, i16) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_int, 32, int32_t, i32) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_int, 64, int64_t, i64) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 1, bool, b) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 8, uint8_t, u8) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 16, uint16_t, u16) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 32, uint32_t, u32) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 64, uint64_t, u64) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_handle, 64, uint64_t, u64) // Handle types are always uint64, regardless of pointer size + default: + _halide_user_assert(0) << "Unsupported type in Callable argument list: " << c_arg.type << "\n"; + } + + #undef HALIDE_HANDLE_TYPE_DISPATCH + + // clang-format on + } + }; + + for (size_t i = 0; i < args.size(); i++) { + const auto &c_arg = c_args[i + 1]; // c_args[0] is the JITUserContext + const size_t slot = i + 1; + define_one_arg(c_arg, args[i], slot); + } + + if (!kwargs.empty()) { + // Also process kwargs. + for (auto kw : kwargs) { + const std::string name = cast_to(kw.first); + const py::handle value = kw.second; + + // Find the slot with this name. + // Skip element 0, since that's always JITUserContext and not visible in Python. + // + // TODO: should we build an inverse map here? For small numbers + // of arguments a linear search is probably faster. + for (size_t slot = 1; slot < argc; slot++) { + const auto &c_arg = c_args[slot]; + if (c_arg.name == name) { + _halide_user_assert(argv[slot] == nullptr) << "Argument " << name << " specified multiple times."; + define_one_arg(c_arg, value, slot); + goto found_kw_arg; + } + } + _halide_user_assert(0) << "Unknown argument '" << name << "' specified via keyword."; + + found_kw_arg: + continue; + } + + // Verify all slots were filled. + for (size_t slot = 1; slot < argc; slot++) { + _halide_user_assert(argv[slot] != nullptr) << "Argument " << c_args[slot].name << " was not specified by either positional or keyword argument."; + } + } else { + // Everything should have been positional + _halide_user_assert(args.size() == argc - 1) + << "Expected exactly " << (argc - 1) << " positional arguments, but saw " << args.size() << "."; + } + + int result = c.call_argv_checked(argc, argv, cci); + _halide_user_assert(result == 0) << "Halide Runtime Error: " << result; + } + +#undef TYPED_ALLOCA +}; + +void define_callable(py::module &m) { + // Not supported yet, because we want to think about how to expose runtime + // overrides in Python (https://github.com/halide/Halide/issues/2790): + // - JITUserContext + + auto callable_class = + py::class_(m, "Callable") + .def("__call__", PyCallable::call_impl); +} + +} // namespace PythonBindings +} // namespace Halide diff --git a/python_bindings/src/halide/halide_/PyCallable.h b/python_bindings/src/halide/halide_/PyCallable.h new file mode 100644 index 000000000000..0e61df4811cc --- /dev/null +++ b/python_bindings/src/halide/halide_/PyCallable.h @@ -0,0 +1,14 @@ +#ifndef HALIDE_PYTHON_BINDINGS_PYCALLABLE_H +#define HALIDE_PYTHON_BINDINGS_PYCALLABLE_H + +#include "PyHalide.h" + +namespace Halide { +namespace PythonBindings { + +void define_callable(py::module &m); + +} // namespace PythonBindings +} // namespace Halide + +#endif // HALIDE_PYTHON_BINDINGS_PYCALLABLE_H diff --git a/python_bindings/src/PyConciseCasts.cpp b/python_bindings/src/halide/halide_/PyConciseCasts.cpp similarity index 79% rename from python_bindings/src/PyConciseCasts.cpp rename to python_bindings/src/halide/halide_/PyConciseCasts.cpp index 4ee79d703673..ebd646c0934b 100644 --- a/python_bindings/src/PyConciseCasts.cpp +++ b/python_bindings/src/halide/halide_/PyConciseCasts.cpp @@ -6,34 +6,34 @@ namespace PythonBindings { void define_concise_casts(py::module &m) { // explicit cast should be tried before // the pybind11::implicitly_convertible conversion - m.def("f64", [](double v) { + m.def("f64", [](double v) -> Expr { return Expr(v); }); - m.def("f32", [](float v) { + m.def("f32", [](float v) -> Expr { return Expr(v); }); - m.def("i64", [](int64_t v) { + m.def("i64", [](int64_t v) -> Expr { return Expr(v); }); - m.def("i32", [](int32_t v) { + m.def("i32", [](int32_t v) -> Expr { return Expr(v); }); - m.def("i16", [](int16_t v) { + m.def("i16", [](int16_t v) -> Expr { return Expr(v); }); - m.def("i8", [](int8_t v) { + m.def("i8", [](int8_t v) -> Expr { return Expr(v); }); - m.def("u64", [](uint64_t v) { + m.def("u64", [](uint64_t v) -> Expr { return Expr(v); }); - m.def("u32", [](uint32_t v) { + m.def("u32", [](uint32_t v) -> Expr { return Expr(v); }); - m.def("u16", [](uint16_t v) { + m.def("u16", [](uint16_t v) -> Expr { return Expr(v); }); - m.def("u8", [](uint8_t v) { + m.def("u8", [](uint8_t v) -> Expr { return Expr(v); }); // pybind11::implicitly_convertible conversions diff --git a/python_bindings/src/PyConciseCasts.h b/python_bindings/src/halide/halide_/PyConciseCasts.h similarity index 100% rename from python_bindings/src/PyConciseCasts.h rename to python_bindings/src/halide/halide_/PyConciseCasts.h diff --git a/python_bindings/src/PyDerivative.cpp b/python_bindings/src/halide/halide_/PyDerivative.cpp similarity index 100% rename from python_bindings/src/PyDerivative.cpp rename to python_bindings/src/halide/halide_/PyDerivative.cpp diff --git a/python_bindings/src/PyDerivative.h b/python_bindings/src/halide/halide_/PyDerivative.h similarity index 100% rename from python_bindings/src/PyDerivative.h rename to python_bindings/src/halide/halide_/PyDerivative.h diff --git a/python_bindings/src/PyEnums.cpp b/python_bindings/src/halide/halide_/PyEnums.cpp similarity index 88% rename from python_bindings/src/PyEnums.cpp rename to python_bindings/src/halide/halide_/PyEnums.cpp index 7961959d1a90..d3cc81022fa0 100644 --- a/python_bindings/src/PyEnums.cpp +++ b/python_bindings/src/halide/halide_/PyEnums.cpp @@ -82,6 +82,22 @@ void define_enums(py::module &m) { .value("RISCV", Target::Arch::RISCV) .value("WebAssembly", Target::Arch::WebAssembly); + // Please keep sorted. + py::enum_(m, "TargetProcessorTune") + .value("TuneAMDFam10", Target::Processor::AMDFam10) + .value("TuneBdVer1", Target::Processor::BdVer1) + .value("TuneBdVer2", Target::Processor::BdVer2) + .value("TuneBdVer3", Target::Processor::BdVer3) + .value("TuneBdVer4", Target::Processor::BdVer4) + .value("TuneBtVer1", Target::Processor::BtVer1) + .value("TuneBtVer2", Target::Processor::BtVer2) + .value("TuneGeneric", Target::Processor::ProcessorGeneric) + .value("TuneK8", Target::Processor::K8) + .value("TuneK8_SSE3", Target::Processor::K8_SSE3) + .value("TuneZnVer1", Target::Processor::ZnVer1) + .value("TuneZnVer2", Target::Processor::ZnVer2) + .value("TuneZnVer3", Target::Processor::ZnVer3); + py::enum_(m, "TargetFeature") .value("JIT", Target::Feature::JIT) .value("Debug", Target::Feature::Debug) @@ -110,7 +126,6 @@ void define_enums(py::module &m) { .value("OpenGLCompute", Target::Feature::OpenGLCompute) .value("EGL", Target::Feature::EGL) .value("UserContext", Target::Feature::UserContext) - .value("Matlab", Target::Feature::Matlab) .value("Profile", Target::Feature::Profile) .value("NoRuntime", Target::Feature::NoRuntime) .value("Metal", Target::Feature::Metal) @@ -141,6 +156,9 @@ void define_enums(py::module &m) { .value("HexagonDma", Target::Feature::HexagonDma) .value("EmbedBitcode", Target::Feature::EmbedBitcode) .value("EnableLLVMLoopOpt", Target::Feature::EnableLLVMLoopOpt) + // halide_target_feature_disable_llvm_loop_opt is deprecated in Halide 15 + // (and will be removed in Halide 16). Halide 15 now defaults to disabling + // LLVM loop optimization, unless halide_target_feature_enable_llvm_loop_opt is set. .value("DisableLLVMLoopOpt", Target::Feature::DisableLLVMLoopOpt) .value("WasmSimd128", Target::Feature::WasmSimd128) .value("WasmSignExt", Target::Feature::WasmSignExt) @@ -156,6 +174,7 @@ void define_enums(py::module &m) { .value("ARMv81a", Target::Feature::ARMv81a) .value("SanitizerCoverage", Target::Feature::SanitizerCoverage) .value("ProfileByTimer", Target::Feature::ProfileByTimer) + .value("SPIRV", Target::Feature::SPIRV) .value("FeatureEnd", Target::Feature::FeatureEnd); py::enum_(m, "TypeCode") diff --git a/python_bindings/src/PyEnums.h b/python_bindings/src/halide/halide_/PyEnums.h similarity index 100% rename from python_bindings/src/PyEnums.h rename to python_bindings/src/halide/halide_/PyEnums.h diff --git a/python_bindings/src/PyError.cpp b/python_bindings/src/halide/halide_/PyError.cpp similarity index 76% rename from python_bindings/src/PyError.cpp rename to python_bindings/src/halide/halide_/PyError.cpp index 714437025bd2..cbcc5b23fe3d 100644 --- a/python_bindings/src/PyError.cpp +++ b/python_bindings/src/halide/halide_/PyError.cpp @@ -36,6 +36,17 @@ void define_error(py::module &m) { handlers.custom_error = halide_python_error; handlers.custom_print = halide_python_print; Halide::Internal::JITSharedRuntime::set_default_handlers(handlers); + + static py::exception halide_error(m, "HalideError"); + py::register_exception_translator([](std::exception_ptr p) { // NOLINT + try { + if (p) { + std::rethrow_exception(p); + } + } catch (const Error &e) { + halide_error(e.what()); + } + }); } } // namespace PythonBindings diff --git a/python_bindings/src/PyError.h b/python_bindings/src/halide/halide_/PyError.h similarity index 100% rename from python_bindings/src/PyError.h rename to python_bindings/src/halide/halide_/PyError.h diff --git a/python_bindings/src/PyEvictionKey.cpp b/python_bindings/src/halide/halide_/PyEvictionKey.cpp similarity index 100% rename from python_bindings/src/PyEvictionKey.cpp rename to python_bindings/src/halide/halide_/PyEvictionKey.cpp diff --git a/python_bindings/src/PyEvictionKey.h b/python_bindings/src/halide/halide_/PyEvictionKey.h similarity index 100% rename from python_bindings/src/PyEvictionKey.h rename to python_bindings/src/halide/halide_/PyEvictionKey.h diff --git a/python_bindings/src/PyExpr.cpp b/python_bindings/src/halide/halide_/PyExpr.cpp similarity index 98% rename from python_bindings/src/PyExpr.cpp rename to python_bindings/src/halide/halide_/PyExpr.cpp index 423f2778e5a2..5f061cd8ff77 100644 --- a/python_bindings/src/PyExpr.cpp +++ b/python_bindings/src/halide/halide_/PyExpr.cpp @@ -48,6 +48,7 @@ void define_expr(py::module &m) { .def("__nonzero__", to_bool) .def("type", &Expr::type) + .def("defined", &Expr::defined) .def("__repr__", [](const Expr &e) -> std::string { std::ostringstream o; o << ""; diff --git a/python_bindings/src/PyExpr.h b/python_bindings/src/halide/halide_/PyExpr.h similarity index 100% rename from python_bindings/src/PyExpr.h rename to python_bindings/src/halide/halide_/PyExpr.h diff --git a/python_bindings/src/PyExternFuncArgument.cpp b/python_bindings/src/halide/halide_/PyExternFuncArgument.cpp similarity index 100% rename from python_bindings/src/PyExternFuncArgument.cpp rename to python_bindings/src/halide/halide_/PyExternFuncArgument.cpp diff --git a/python_bindings/src/PyExternFuncArgument.h b/python_bindings/src/halide/halide_/PyExternFuncArgument.h similarity index 100% rename from python_bindings/src/PyExternFuncArgument.h rename to python_bindings/src/halide/halide_/PyExternFuncArgument.h diff --git a/python_bindings/src/PyFunc.cpp b/python_bindings/src/halide/halide_/PyFunc.cpp similarity index 73% rename from python_bindings/src/PyFunc.cpp rename to python_bindings/src/halide/halide_/PyFunc.cpp index 87fc9d4aa574..bec3402fc0f6 100644 --- a/python_bindings/src/PyFunc.cpp +++ b/python_bindings/src/halide/halide_/PyFunc.cpp @@ -106,11 +106,14 @@ void define_func(py::module &m) { // - set_error_handler() // - set_custom_trace() // - set_custom_print() + // - JITUserContext auto func_class = py::class_(m, "Func") .def(py::init<>()) .def(py::init()) + .def(py::init(), py::arg("required_type"), py::arg("required_dimensions"), py::arg("name")) + .def(py::init, int, std::string>(), py::arg("required_types"), py::arg("required_dimensions"), py::arg("name")) .def(py::init()) .def(py::init([](Buffer<> &b) -> Func { return Func(b); })) @@ -125,15 +128,12 @@ void define_func(py::module &m) { }, py::arg("dst"), py::arg("target") = Target()) - // This will actually allow a list-of-buffers as well as a tuple-of-buffers, but that's OK. - .def( - "realize", - [](Func &f, std::vector> buffers, const Target &t) -> void { - py::gil_scoped_release release; - f.realize(Realization(buffers), t); - }, - py::arg("dst"), py::arg("target") = Target()) - + // It's important to have this overload of realize() go first: + // passing an empty list [] is ambiguous in Python, and could match to + // either list-of-sizes or list-of-buffers... but the former is useful + // (it allows realizing a 0-dimensional/scalar buffer) and the former is + // not (it will always assert-fail). Putting this one first allows it to + // be the first one chosen by the bindings in this case. .def( "realize", [](Func &f, const std::vector &sizes, const Target &target) -> py::object { @@ -146,6 +146,15 @@ void define_func(py::module &m) { }, py::arg("sizes") = std::vector{}, py::arg("target") = Target()) + // This will actually allow a list-of-buffers as well as a tuple-of-buffers, but that's OK. + .def( + "realize", + [](Func &f, std::vector> buffers, const Target &t) -> void { + py::gil_scoped_release release; + f.realize(Realization(std::move(buffers)), t); + }, + py::arg("dst"), py::arg("target") = Target()) + .def("defined", &Func::defined) .def("name", &Func::name) .def("get_schedule_dim_var_name", &Func::get_schedule_dim_var_name, py::arg("i")) @@ -157,7 +166,25 @@ void define_func(py::module &m) { }) .def("defined", &Func::defined) .def("outputs", &Func::outputs) - .def("output_types", &Func::output_types) + + .def("output_type", [](Func &f) { + // HALIDE_ATTRIBUTE_DEPRECATED("Func::output_type() is deprecated; call Func::type() instead.") + PyErr_WarnEx(PyExc_DeprecationWarning, + "Func.output_type() is deprecated; use Func.type() instead.", + 1); + return f.type(); + }) + + .def("output_types", [](Func &f) { + // HALIDE_ATTRIBUTE_DEPRECATED("Func::output_types() is deprecated; call Func::types() instead.") + PyErr_WarnEx(PyExc_DeprecationWarning, + "Func.output_types() is deprecated; use Func.types() instead.", + 1); + return f.types(); + }) + + .def("type", &Func::type) + .def("types", &Func::types) .def("bound", &Func::bound, py::arg("var"), py::arg("min"), py::arg("extent")) @@ -186,19 +213,19 @@ void define_func(py::module &m) { .def("compile_to", &Func::compile_to, py::arg("outputs"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_bitcode", (void (Func::*)(const std::string &, const std::vector &, const std::string &, const Target &target)) & Func::compile_to_bitcode, py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_bitcode", (void (Func::*)(const std::string &, const std::vector &, const Target &target)) & Func::compile_to_bitcode, py::arg("filename"), py::arg("arguments"), py::arg("target") = get_target_from_environment()) + .def("compile_to_bitcode", (void(Func::*)(const std::string &, const std::vector &, const std::string &, const Target &target)) & Func::compile_to_bitcode, py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) + .def("compile_to_bitcode", (void(Func::*)(const std::string &, const std::vector &, const Target &target)) & Func::compile_to_bitcode, py::arg("filename"), py::arg("arguments"), py::arg("target") = get_target_from_environment()) - .def("compile_to_llvm_assembly", (void (Func::*)(const std::string &, const std::vector &, const std::string &, const Target &target)) & Func::compile_to_llvm_assembly, py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_llvm_assembly", (void (Func::*)(const std::string &, const std::vector &, const Target &target)) & Func::compile_to_llvm_assembly, py::arg("filename"), py::arg("arguments"), py::arg("target") = get_target_from_environment()) + .def("compile_to_llvm_assembly", (void(Func::*)(const std::string &, const std::vector &, const std::string &, const Target &target)) & Func::compile_to_llvm_assembly, py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) + .def("compile_to_llvm_assembly", (void(Func::*)(const std::string &, const std::vector &, const Target &target)) & Func::compile_to_llvm_assembly, py::arg("filename"), py::arg("arguments"), py::arg("target") = get_target_from_environment()) - .def("compile_to_object", (void (Func::*)(const std::string &, const std::vector &, const std::string &, const Target &target)) & Func::compile_to_object, py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_object", (void (Func::*)(const std::string &, const std::vector &, const Target &target)) & Func::compile_to_object, py::arg("filename"), py::arg("arguments"), py::arg("target") = get_target_from_environment()) + .def("compile_to_object", (void(Func::*)(const std::string &, const std::vector &, const std::string &, const Target &target)) & Func::compile_to_object, py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) + .def("compile_to_object", (void(Func::*)(const std::string &, const std::vector &, const Target &target)) & Func::compile_to_object, py::arg("filename"), py::arg("arguments"), py::arg("target") = get_target_from_environment()) .def("compile_to_header", &Func::compile_to_header, py::arg("filename"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = get_target_from_environment()) - .def("compile_to_assembly", (void (Func::*)(const std::string &, const std::vector &, const std::string &, const Target &target)) & Func::compile_to_assembly, py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_assembly", (void (Func::*)(const std::string &, const std::vector &, const Target &target)) & Func::compile_to_assembly, py::arg("filename"), py::arg("arguments"), py::arg("target") = get_target_from_environment()) + .def("compile_to_assembly", (void(Func::*)(const std::string &, const std::vector &, const std::string &, const Target &target)) & Func::compile_to_assembly, py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) + .def("compile_to_assembly", (void(Func::*)(const std::string &, const std::vector &, const Target &target)) & Func::compile_to_assembly, py::arg("filename"), py::arg("arguments"), py::arg("target") = get_target_from_environment()) .def("compile_to_c", &Func::compile_to_c, py::arg("filename"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = get_target_from_environment()) @@ -216,6 +243,8 @@ void define_func(py::module &m) { .def("compile_jit", &Func::compile_jit, py::arg("target") = get_jit_target_from_environment()) + .def("compile_to_callable", &Func::compile_to_callable, py::arg("arguments"), py::arg("target") = get_jit_target_from_environment()) + .def("has_update_definition", &Func::has_update_definition) .def("num_update_definitions", &Func::num_update_definitions) @@ -242,13 +271,13 @@ void define_func(py::module &m) { .def("is_extern", &Func::is_extern) .def("extern_function_name", &Func::extern_function_name) - .def("define_extern", (void (Func::*)(const std::string &, const std::vector &, const std::vector &, const std::vector &, NameMangling, DeviceAPI)) & Func::define_extern, py::arg("function_name"), py::arg("params"), py::arg("types"), py::arg("arguments"), py::arg("mangling") = NameMangling::Default, py::arg("device_api") = DeviceAPI::Host) + .def("define_extern", (void(Func::*)(const std::string &, const std::vector &, const std::vector &, const std::vector &, NameMangling, DeviceAPI)) & Func::define_extern, py::arg("function_name"), py::arg("params"), py::arg("types"), py::arg("arguments"), py::arg("mangling") = NameMangling::Default, py::arg("device_api") = DeviceAPI::Host) - .def("define_extern", (void (Func::*)(const std::string &, const std::vector &, Type, int, NameMangling, DeviceAPI)) & Func::define_extern, py::arg("function_name"), py::arg("params"), py::arg("type"), py::arg("dimensionality"), py::arg("mangling") = NameMangling::Default, py::arg("device_api") = DeviceAPI::Host) + .def("define_extern", (void(Func::*)(const std::string &, const std::vector &, Type, int, NameMangling, DeviceAPI)) & Func::define_extern, py::arg("function_name"), py::arg("params"), py::arg("type"), py::arg("dimensionality"), py::arg("mangling") = NameMangling::Default, py::arg("device_api") = DeviceAPI::Host) - .def("define_extern", (void (Func::*)(const std::string &, const std::vector &, const std::vector &, int, NameMangling, DeviceAPI)) & Func::define_extern, py::arg("function_name"), py::arg("params"), py::arg("types"), py::arg("dimensionality"), py::arg("mangling") = NameMangling::Default, py::arg("device_api") = DeviceAPI::Host) + .def("define_extern", (void(Func::*)(const std::string &, const std::vector &, const std::vector &, int, NameMangling, DeviceAPI)) & Func::define_extern, py::arg("function_name"), py::arg("params"), py::arg("types"), py::arg("dimensionality"), py::arg("mangling") = NameMangling::Default, py::arg("device_api") = DeviceAPI::Host) - .def("define_extern", (void (Func::*)(const std::string &, const std::vector &, Type, const std::vector &, NameMangling, DeviceAPI)) & Func::define_extern, py::arg("function_name"), py::arg("params"), py::arg("type"), py::arg("arguments"), py::arg("mangling") = NameMangling::Default, py::arg("device_api") = DeviceAPI::Host) + .def("define_extern", (void(Func::*)(const std::string &, const std::vector &, Type, const std::vector &, NameMangling, DeviceAPI)) & Func::define_extern, py::arg("function_name"), py::arg("params"), py::arg("type"), py::arg("arguments"), py::arg("mangling") = NameMangling::Default, py::arg("device_api") = DeviceAPI::Host) .def("output_buffer", &Func::output_buffer) .def("output_buffers", &Func::output_buffers) @@ -266,7 +295,7 @@ void define_func(py::module &m) { try { std::vector> v = dst.cast>>(); - f.infer_input_bounds(Realization(v), target); + f.infer_input_bounds(Realization(std::move(v)), target); return; } catch (...) { // fall thru @@ -306,9 +335,6 @@ void define_func(py::module &m) { .def("fold_storage", &Func::fold_storage, py::arg("dim"), py::arg("extent"), py::arg("fold_forward") = true) - .def("compute_with", (Func & (Func::*)(LoopLevel, const std::vector> &)) & Func::compute_with, py::arg("loop_level"), py::arg("align")) - .def("compute_with", (Func & (Func::*)(LoopLevel, LoopAlignStrategy)) & Func::compute_with, py::arg("loop_level"), py::arg("align") = LoopAlignStrategy::Auto) - .def("infer_arguments", &Func::infer_arguments) .def("__repr__", [](const Func &func) -> std::string { @@ -350,7 +376,7 @@ void define_func(py::module &m) { define_set(func_class); define_set(func_class); define_set(func_class); - //define_set>(func_class); + // define_set>(func_class); // LHS(Expr, ...Expr) can only be LHS of an update definition. define_set(func_class); @@ -359,8 +385,6 @@ void define_func(py::module &m) { add_schedule_methods(func_class); - py::implicitly_convertible(); - define_stage(m); } diff --git a/python_bindings/src/PyFunc.h b/python_bindings/src/halide/halide_/PyFunc.h similarity index 100% rename from python_bindings/src/PyFunc.h rename to python_bindings/src/halide/halide_/PyFunc.h diff --git a/python_bindings/src/PyFuncRef.cpp b/python_bindings/src/halide/halide_/PyFuncRef.cpp similarity index 100% rename from python_bindings/src/PyFuncRef.cpp rename to python_bindings/src/halide/halide_/PyFuncRef.cpp diff --git a/python_bindings/src/PyFuncRef.h b/python_bindings/src/halide/halide_/PyFuncRef.h similarity index 100% rename from python_bindings/src/PyFuncRef.h rename to python_bindings/src/halide/halide_/PyFuncRef.h diff --git a/python_bindings/src/PyHalide.cpp b/python_bindings/src/halide/halide_/PyHalide.cpp similarity index 93% rename from python_bindings/src/PyHalide.cpp rename to python_bindings/src/halide/halide_/PyHalide.cpp index 02eb413902f8..8467db76cc0e 100644 --- a/python_bindings/src/PyHalide.cpp +++ b/python_bindings/src/halide/halide_/PyHalide.cpp @@ -3,6 +3,7 @@ #include "PyArgument.h" #include "PyBoundaryConditions.h" #include "PyBuffer.h" +#include "PyCallable.h" #include "PyConciseCasts.h" #include "PyDerivative.h" #include "PyEnums.h" @@ -14,7 +15,9 @@ #include "PyImageParam.h" #include "PyInlineReductions.h" #include "PyLambda.h" +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API #include "PyMachineParams.h" +#endif #include "PyModule.h" #include "PyParam.h" #include "PyPipeline.h" @@ -33,7 +36,7 @@ static_assert(PY_VERSION_HEX >= 0x03000000, "We appear to be compiling against Python 2.x rather than 3.x, which is not supported."); #ifndef HALIDE_PYBIND_MODULE_NAME -#define HALIDE_PYBIND_MODULE_NAME halide +#define HALIDE_PYBIND_MODULE_NAME halide_ #endif PYBIND11_MODULE(HALIDE_PYBIND_MODULE_NAME, m) { @@ -57,8 +60,11 @@ PYBIND11_MODULE(HALIDE_PYBIND_MODULE_NAME, m) { define_extern_func_argument(m); define_var(m); define_rdom(m); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API define_machine_params(m); +#endif define_module(m); + define_callable(m); define_func(m); define_pipeline(m); define_inline_reductions(m); diff --git a/python_bindings/src/PyHalide.h b/python_bindings/src/halide/halide_/PyHalide.h similarity index 100% rename from python_bindings/src/PyHalide.h rename to python_bindings/src/halide/halide_/PyHalide.h diff --git a/python_bindings/src/PyIROperator.cpp b/python_bindings/src/halide/halide_/PyIROperator.cpp similarity index 100% rename from python_bindings/src/PyIROperator.cpp rename to python_bindings/src/halide/halide_/PyIROperator.cpp diff --git a/python_bindings/src/PyIROperator.h b/python_bindings/src/halide/halide_/PyIROperator.h similarity index 100% rename from python_bindings/src/PyIROperator.h rename to python_bindings/src/halide/halide_/PyIROperator.h diff --git a/python_bindings/src/PyImageParam.cpp b/python_bindings/src/halide/halide_/PyImageParam.cpp similarity index 98% rename from python_bindings/src/PyImageParam.cpp rename to python_bindings/src/halide/halide_/PyImageParam.cpp index 1a3e35f50a4d..dfc6a8f66c7c 100644 --- a/python_bindings/src/PyImageParam.cpp +++ b/python_bindings/src/halide/halide_/PyImageParam.cpp @@ -57,7 +57,7 @@ void define_image_param(py::module &m) { auto image_param_class = py::class_(m, "ImageParam", output_image_param_class) .def(py::init<>()) - .def(py::init()) + .def(py::init(), py::arg("type"), py::arg("dimensions")) .def(py::init(), py::arg("type"), py::arg("dimensions"), py::arg("name")) .def("set", &ImageParam::set) .def("get", &ImageParam::get) diff --git a/python_bindings/src/PyImageParam.h b/python_bindings/src/halide/halide_/PyImageParam.h similarity index 100% rename from python_bindings/src/PyImageParam.h rename to python_bindings/src/halide/halide_/PyImageParam.h diff --git a/python_bindings/src/PyInlineReductions.cpp b/python_bindings/src/halide/halide_/PyInlineReductions.cpp similarity index 100% rename from python_bindings/src/PyInlineReductions.cpp rename to python_bindings/src/halide/halide_/PyInlineReductions.cpp diff --git a/python_bindings/src/PyInlineReductions.h b/python_bindings/src/halide/halide_/PyInlineReductions.h similarity index 100% rename from python_bindings/src/PyInlineReductions.h rename to python_bindings/src/halide/halide_/PyInlineReductions.h diff --git a/python_bindings/src/PyLambda.cpp b/python_bindings/src/halide/halide_/PyLambda.cpp similarity index 100% rename from python_bindings/src/PyLambda.cpp rename to python_bindings/src/halide/halide_/PyLambda.cpp diff --git a/python_bindings/src/PyLambda.h b/python_bindings/src/halide/halide_/PyLambda.h similarity index 100% rename from python_bindings/src/PyLambda.h rename to python_bindings/src/halide/halide_/PyLambda.h diff --git a/python_bindings/src/PyLoopLevel.cpp b/python_bindings/src/halide/halide_/PyLoopLevel.cpp similarity index 79% rename from python_bindings/src/PyLoopLevel.cpp rename to python_bindings/src/halide/halide_/PyLoopLevel.cpp index dade7619444d..b2db4b5d035d 100644 --- a/python_bindings/src/PyLoopLevel.cpp +++ b/python_bindings/src/halide/halide_/PyLoopLevel.cpp @@ -17,7 +17,9 @@ void define_loop_level(py::module &m) { .def_static("root", &LoopLevel::root) .def("__repr__", [](const LoopLevel &b) -> std::string { std::ostringstream o; - o << ""; + // b.to_string() fails for locked LoopLevels. Just output something generic. + // o << ""; + o << ""; return o.str(); }); } diff --git a/python_bindings/src/PyLoopLevel.h b/python_bindings/src/halide/halide_/PyLoopLevel.h similarity index 100% rename from python_bindings/src/PyLoopLevel.h rename to python_bindings/src/halide/halide_/PyLoopLevel.h diff --git a/python_bindings/src/PyMachineParams.cpp b/python_bindings/src/halide/halide_/PyMachineParams.cpp similarity index 95% rename from python_bindings/src/PyMachineParams.cpp rename to python_bindings/src/halide/halide_/PyMachineParams.cpp index e99dd594b11d..93c49d97fae6 100644 --- a/python_bindings/src/PyMachineParams.cpp +++ b/python_bindings/src/halide/halide_/PyMachineParams.cpp @@ -1,3 +1,4 @@ +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API #include "PyMachineParams.h" namespace Halide { @@ -23,3 +24,4 @@ void define_machine_params(py::module &m) { } // namespace PythonBindings } // namespace Halide +#endif diff --git a/python_bindings/src/PyMachineParams.h b/python_bindings/src/halide/halide_/PyMachineParams.h similarity index 86% rename from python_bindings/src/PyMachineParams.h rename to python_bindings/src/halide/halide_/PyMachineParams.h index aa15ee73c069..82b4ff3ac441 100644 --- a/python_bindings/src/PyMachineParams.h +++ b/python_bindings/src/halide/halide_/PyMachineParams.h @@ -1,6 +1,7 @@ #ifndef HALIDE_PYTHON_BINDINGS_PYMACHINEPARAMS_H #define HALIDE_PYTHON_BINDINGS_PYMACHINEPARAMS_H +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API #include "PyHalide.h" namespace Halide { @@ -11,4 +12,5 @@ void define_machine_params(py::module &m); } // namespace PythonBindings } // namespace Halide +#endif #endif // HALIDE_PYTHON_BINDINGS_PYMACHINEPARAMS_H diff --git a/python_bindings/src/PyModule.cpp b/python_bindings/src/halide/halide_/PyModule.cpp similarity index 90% rename from python_bindings/src/PyModule.cpp rename to python_bindings/src/halide/halide_/PyModule.cpp index 0d6b62104601..e46827de5115 100644 --- a/python_bindings/src/PyModule.cpp +++ b/python_bindings/src/halide/halide_/PyModule.cpp @@ -12,9 +12,13 @@ void define_module(py::module &m) { auto auto_scheduler_results_class = py::class_(m, "AutoSchedulerResults") .def(py::init<>()) - .def_readwrite("scheduler_name", &AutoSchedulerResults::scheduler_name) .def_readwrite("target", &AutoSchedulerResults::target) +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + .def_readwrite("scheduler_name", &AutoSchedulerResults::scheduler_name) .def_readwrite("machine_params_string", &AutoSchedulerResults::machine_params_string) +#else + .def_readwrite("autoscheduler_params", &AutoSchedulerResults::autoscheduler_params) +#endif .def_readwrite("schedule_source", &AutoSchedulerResults::schedule_source) .def_readwrite("python_schedule_source", &AutoSchedulerResults::python_schedule_source) .def_readwrite("featurization", &AutoSchedulerResults::featurization) @@ -33,8 +37,8 @@ void define_module(py::module &m) { .def("buffers", &Module::buffers) .def("submodules", &Module::submodules) - .def("append", (void (Module::*)(const Buffer<> &)) & Module::append, py::arg("buffer")) - .def("append", (void (Module::*)(const Module &)) & Module::append, py::arg("module")) + .def("append", (void(Module::*)(const Buffer<> &)) & Module::append, py::arg("buffer")) + .def("append", (void(Module::*)(const Module &)) & Module::append, py::arg("module")) .def("compile", &Module::compile, py::arg("outputs")) diff --git a/python_bindings/src/PyModule.h b/python_bindings/src/halide/halide_/PyModule.h similarity index 100% rename from python_bindings/src/PyModule.h rename to python_bindings/src/halide/halide_/PyModule.h diff --git a/python_bindings/src/PyParam.cpp b/python_bindings/src/halide/halide_/PyParam.cpp similarity index 91% rename from python_bindings/src/PyParam.cpp rename to python_bindings/src/halide/halide_/PyParam.cpp index 43379342e26a..5a80ca774ec7 100644 --- a/python_bindings/src/PyParam.cpp +++ b/python_bindings/src/halide/halide_/PyParam.cpp @@ -11,6 +11,12 @@ namespace { template void add_param_methods(py::class_> ¶m_class) { param_class + .def(py::init([](const Type &type, TYPE value) { + Param<> param(type); + param.set(value); + return param; + }), + py::arg("type"), py::arg("value")) .def(py::init([](const Type &type, const std::string &name, TYPE value) { Param<> param(type, name); param.set(value); diff --git a/python_bindings/src/PyParam.h b/python_bindings/src/halide/halide_/PyParam.h similarity index 100% rename from python_bindings/src/PyParam.h rename to python_bindings/src/halide/halide_/PyParam.h diff --git a/python_bindings/src/PyPipeline.cpp b/python_bindings/src/halide/halide_/PyPipeline.cpp similarity index 72% rename from python_bindings/src/PyPipeline.cpp rename to python_bindings/src/halide/halide_/PyPipeline.cpp index 16ba121e47fe..d775d2ef4b75 100644 --- a/python_bindings/src/PyPipeline.cpp +++ b/python_bindings/src/halide/halide_/PyPipeline.cpp @@ -42,6 +42,32 @@ void define_pipeline(py::module &m) { // - set_custom_trace() // - set_custom_print() +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API +// nothing +#else + py::class_(m, "AutoschedulerParams") + .def(py::init<>()) + .def(py::init(), py::arg("name")) + .def(py::init([](const std::string &name, const py::dict &extra) -> AutoschedulerParams { + // Manually convert the dict: + // we want to allow Python to pass in dicts that have non-string values for some keys; + // PyBind will reject these as a type failure. We'll stringify them here explicitly. + AutoschedulerParams asp(name); + for (auto item : extra) { + const std::string name = py::str(item.first).cast(); + const std::string value = py::str(item.second).cast(); + asp.extra[name] = value; + } + return asp; + }), + py::arg("target"), py::arg("autoscheduler_params")) + .def_readwrite("name", &AutoschedulerParams::name) + .def_readwrite("extra", &AutoschedulerParams::extra) + .def("__repr__", [](const AutoSchedulerResults &o) -> std::string { + return ""; + }); +#endif + auto pipeline_class = py::class_(m, "Pipeline") .def(py::init<>()) @@ -50,9 +76,10 @@ void define_pipeline(py::module &m) { .def("outputs", &Pipeline::outputs) - .def("auto_schedule", (AutoSchedulerResults(Pipeline::*)(const std::string &, const Target &, const MachineParams &)) & Pipeline::auto_schedule, +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + .def("auto_schedule", (AutoSchedulerResults(Pipeline::*)(const std::string &, const Target &, const MachineParams &) const) & Pipeline::auto_schedule, py::arg("autoscheduler_name"), py::arg("target"), py::arg("machine_params") = MachineParams::generic()) - .def("auto_schedule", (AutoSchedulerResults(Pipeline::*)(const Target &, const MachineParams &)) & Pipeline::auto_schedule, + .def("auto_schedule", (AutoSchedulerResults(Pipeline::*)(const Target &, const MachineParams &) const) & Pipeline::auto_schedule, py::arg("target"), py::arg("machine_params") = MachineParams::generic()) .def("apply_python_schedule", [](Pipeline& pipeline, const Target& target) { auto simplify_name = [](const std::string &s) -> std::string { @@ -76,7 +103,10 @@ void define_pipeline(py::module &m) { .def_static("set_default_autoscheduler_name", &Pipeline::set_default_autoscheduler_name, py::arg("autoscheduler_name")) - +#else + .def("apply_autoscheduler", (AutoSchedulerResults(Pipeline::*)(const Target &, const AutoschedulerParams &) const) & Pipeline::apply_autoscheduler, + py::arg("target"), py::arg("autoscheduler_params")) +#endif .def("get_func", &Pipeline::get_func, py::arg("index")) .def("print_loop_nest", &Pipeline::print_loop_nest) @@ -114,21 +144,21 @@ void define_pipeline(py::module &m) { .def("compile_jit", &Pipeline::compile_jit, py::arg("target") = get_jit_target_from_environment()) - .def( - "realize", [](Pipeline &p, Buffer<> buffer, const Target &target) -> void { - py::gil_scoped_release release; - p.realize(Realization(buffer), target); - }, - py::arg("dst"), py::arg("target") = Target()) + .def("compile_to_callable", &Pipeline::compile_to_callable, py::arg("arguments"), py::arg("target") = get_jit_target_from_environment()) - // This will actually allow a list-of-buffers as well as a tuple-of-buffers, but that's OK. .def( - "realize", [](Pipeline &p, std::vector> buffers, const Target &t) -> void { + "realize", [](Pipeline &p, Buffer<> buffer, const Target &target) -> void { py::gil_scoped_release release; - p.realize(Realization(buffers), t); + p.realize(Realization(std::move(buffer)), target); }, py::arg("dst"), py::arg("target") = Target()) + // It's important to have this overload of realize() go first: + // passing an empty list [] is ambiguous in Python, and could match to + // either list-of-sizes or list-of-buffers... but the former is useful + // (it allows realizing a 0-dimensional/scalar buffer) and the former is + // not (it will always assert-fail). Putting this one first allows it to + // be the first one chosen by the bindings in this case. .def( "realize", [](Pipeline &p, std::vector sizes, const Target &target) -> py::object { std::optional r; @@ -140,6 +170,14 @@ void define_pipeline(py::module &m) { }, py::arg("sizes") = std::vector{}, py::arg("target") = Target()) + // This will actually allow a list-of-buffers as well as a tuple-of-buffers, but that's OK. + .def( + "realize", [](Pipeline &p, std::vector> buffers, const Target &t) -> void { + py::gil_scoped_release release; + p.realize(Realization(std::move(buffers)), t); + }, + py::arg("dst"), py::arg("target") = Target()) + .def( "infer_input_bounds", [](Pipeline &p, const py::object &dst, const Target &target) -> void { // dst could be Buffer<>, vector, or vector @@ -153,7 +191,7 @@ void define_pipeline(py::module &m) { try { std::vector> v = dst.cast>>(); - p.infer_input_bounds(Realization(v), target); + p.infer_input_bounds(Realization(std::move(v)), target); return; } catch (...) { // fall thru @@ -189,6 +227,19 @@ void define_pipeline(py::module &m) { o << "]>"; return o.str(); }); + + // TODO: These should really live in PyGenerator.cpp once that lands + m.def( + "create_callable_from_generator", [](const GeneratorContext &context, const std::string &name, const std::map &generator_params) -> Callable { + return create_callable_from_generator(context, name, generator_params); + }, + py::arg("context"), py::arg("name"), py::arg("generator_params") = std::map{}); + + m.def( + "create_callable_from_generator", [](const Target &target, const std::string &name, const std::map &generator_params) -> Callable { + return create_callable_from_generator(target, name, generator_params); + }, + py::arg("target"), py::arg("name"), py::arg("generator_params") = std::map{}); } } // namespace PythonBindings diff --git a/python_bindings/src/PyPipeline.h b/python_bindings/src/halide/halide_/PyPipeline.h similarity index 100% rename from python_bindings/src/PyPipeline.h rename to python_bindings/src/halide/halide_/PyPipeline.h diff --git a/python_bindings/src/PyRDom.cpp b/python_bindings/src/halide/halide_/PyRDom.cpp similarity index 100% rename from python_bindings/src/PyRDom.cpp rename to python_bindings/src/halide/halide_/PyRDom.cpp diff --git a/python_bindings/src/PyRDom.h b/python_bindings/src/halide/halide_/PyRDom.h similarity index 100% rename from python_bindings/src/PyRDom.h rename to python_bindings/src/halide/halide_/PyRDom.h diff --git a/python_bindings/src/PyScheduleMethods.h b/python_bindings/src/halide/halide_/PyScheduleMethods.h similarity index 87% rename from python_bindings/src/PyScheduleMethods.h rename to python_bindings/src/halide/halide_/PyScheduleMethods.h index 8280ccd99ec5..9086bbafc5c0 100644 --- a/python_bindings/src/PyScheduleMethods.h +++ b/python_bindings/src/halide/halide_/PyScheduleMethods.h @@ -17,6 +17,10 @@ HALIDE_NEVER_INLINE void add_schedule_methods(PythonClass &class_instance) { py::arg("stage"), py::arg("var"), py::arg("align")) .def("compute_with", (T & (T::*)(const Stage &, const VarOrRVar &, LoopAlignStrategy)) & T::compute_with, py::arg("stage"), py::arg("var"), py::arg("align") = LoopAlignStrategy::Auto) + .def("compute_with", (T & (T::*)(LoopLevel, const std::vector> &)) & T::compute_with, + py::arg("loop_level"), py::arg("align")) + .def("compute_with", (T & (T::*)(LoopLevel, LoopAlignStrategy)) & T::compute_with, + py::arg("loop_level"), py::arg("align") = LoopAlignStrategy::Auto) .def("unroll", (T & (T::*)(const VarOrRVar &)) & T::unroll, py::arg("var")) @@ -88,26 +92,6 @@ HALIDE_NEVER_INLINE void add_schedule_methods(PythonClass &class_instance) { .def("hexagon", &T::hexagon, py::arg("x") = Var::outermost()) - .def( - "prefetch", [](T &t, const Func &f, const VarOrRVar &var, int offset, PrefetchBoundStrategy strategy) -> T & { - // HALIDE_ATTRIBUTE_DEPRECATED("Call prefetch() with the two-var form instead.") - PyErr_WarnEx(PyExc_DeprecationWarning, - "Call prefetch() with the two-var form instead.", - 1); - return t.prefetch(f, var, var, offset, strategy); - }, - py::arg("image"), py::arg("var"), py::arg("offset") = 1, py::arg("strategy") = PrefetchBoundStrategy::GuardWithIf) - .def( - "prefetch", [](T &t, const ImageParam &image, const VarOrRVar &var, int offset, PrefetchBoundStrategy strategy) -> T & { - // HALIDE_ATTRIBUTE_DEPRECATED("Call prefetch() with the two-var form instead.") - PyErr_WarnEx(PyExc_DeprecationWarning, - "Call prefetch() with the two-var form instead.", - 1); - // Templated function; specializing only on ImageParam for now - return t.template prefetch(image, var, var, offset, strategy); - }, - py::arg("image"), py::arg("var"), py::arg("offset") = 1, py::arg("strategy") = PrefetchBoundStrategy::GuardWithIf) - .def("prefetch", (T & (T::*)(const Func &, const VarOrRVar &, const VarOrRVar &, Expr, PrefetchBoundStrategy)) & T::prefetch, py::arg("func"), py::arg("at"), py::arg("from"), py::arg("offset") = 1, py::arg("strategy") = PrefetchBoundStrategy::GuardWithIf) .def( "prefetch", [](T &t, const ImageParam &image, const VarOrRVar &at, const VarOrRVar &from, const Expr &offset, PrefetchBoundStrategy strategy) -> T & { diff --git a/python_bindings/src/PyStage.cpp b/python_bindings/src/halide/halide_/PyStage.cpp similarity index 58% rename from python_bindings/src/PyStage.cpp rename to python_bindings/src/halide/halide_/PyStage.cpp index 406183955f1d..4d107062a1ae 100644 --- a/python_bindings/src/PyStage.cpp +++ b/python_bindings/src/halide/halide_/PyStage.cpp @@ -8,6 +8,9 @@ namespace PythonBindings { void define_stage(py::module &m) { auto stage_class = py::class_(m, "Stage") + // for implicitly_convertible + .def(py::init([](const Func &f) -> Stage { return f; })) + .def("dump_argument_list", &Stage::dump_argument_list) .def("name", &Stage::name) .def("get_schedule_dim_var_name", &Stage::get_schedule_dim_var_name, py::arg("i")) @@ -15,13 +18,10 @@ void define_stage(py::module &m) { .def("rfactor", (Func(Stage::*)(std::vector>)) & Stage::rfactor, py::arg("preserved")) .def("rfactor", (Func(Stage::*)(const RVar &, const Var &)) & Stage::rfactor, - py::arg("r"), py::arg("v")) + py::arg("r"), py::arg("v")); + + py::implicitly_convertible(); - // These two variants of compute_with are specific to Stage - .def("compute_with", (Stage & (Stage::*)(LoopLevel, const std::vector> &)) & Stage::compute_with, - py::arg("loop_level"), py::arg("align")) - .def("compute_with", (Stage & (Stage::*)(LoopLevel, LoopAlignStrategy)) & Stage::compute_with, - py::arg("loop_level"), py::arg("align") = LoopAlignStrategy::Auto); add_schedule_methods(stage_class); } diff --git a/python_bindings/src/PyStage.h b/python_bindings/src/halide/halide_/PyStage.h similarity index 100% rename from python_bindings/src/PyStage.h rename to python_bindings/src/halide/halide_/PyStage.h diff --git a/python_bindings/src/PyTarget.cpp b/python_bindings/src/halide/halide_/PyTarget.cpp similarity index 90% rename from python_bindings/src/PyTarget.cpp rename to python_bindings/src/halide/halide_/PyTarget.cpp index 718936332ea9..25a822948709 100644 --- a/python_bindings/src/PyTarget.cpp +++ b/python_bindings/src/halide/halide_/PyTarget.cpp @@ -24,7 +24,9 @@ void define_target(py::module &m) { .def(py::init<>()) .def(py::init()) .def(py::init()) + .def(py::init()) .def(py::init>()) + .def(py::init>()) .def("__eq__", [](const Target &value, Target *value2) { return value2 && value == *value2; }) .def("__ne__", [](const Target &value, Target *value2) { return !value2 || value != *value2; }) @@ -32,12 +34,13 @@ void define_target(py::module &m) { .def_readwrite("os", &Target::os) .def_readwrite("arch", &Target::arch) .def_readwrite("bits", &Target::bits) + .def_readwrite("processor_tune", &Target::processor_tune) .def("__repr__", &target_repr) .def("__str__", &Target::to_string) .def("to_string", &Target::to_string) - .def("has_feature", (bool (Target::*)(Target::Feature) const) & Target::has_feature) + .def("has_feature", (bool(Target::*)(Target::Feature) const) & Target::has_feature) .def("features_any_of", &Target::features_any_of, py::arg("features")) .def("features_all_of", &Target::features_all_of, py::arg("features")) diff --git a/python_bindings/src/PyTarget.h b/python_bindings/src/halide/halide_/PyTarget.h similarity index 100% rename from python_bindings/src/PyTarget.h rename to python_bindings/src/halide/halide_/PyTarget.h diff --git a/python_bindings/src/PyTuple.cpp b/python_bindings/src/halide/halide_/PyTuple.cpp similarity index 100% rename from python_bindings/src/PyTuple.cpp rename to python_bindings/src/halide/halide_/PyTuple.cpp diff --git a/python_bindings/src/PyTuple.h b/python_bindings/src/halide/halide_/PyTuple.h similarity index 100% rename from python_bindings/src/PyTuple.h rename to python_bindings/src/halide/halide_/PyTuple.h diff --git a/python_bindings/src/PyType.cpp b/python_bindings/src/halide/halide_/PyType.cpp similarity index 92% rename from python_bindings/src/PyType.cpp rename to python_bindings/src/halide/halide_/PyType.cpp index d71c7fabf3ca..0f8c383bdf33 100644 --- a/python_bindings/src/PyType.cpp +++ b/python_bindings/src/halide/halide_/PyType.cpp @@ -76,12 +76,12 @@ void define_type(py::module &m) { // .def("__lt__", [](const Type &value, Type *value2) -> bool { return value2 && value < *value2; }) .def("element_of", &Type::element_of) - .def("can_represent", (bool (Type::*)(Type) const) & Type::can_represent, py::arg("other")) + .def("can_represent", (bool(Type::*)(Type) const) & Type::can_represent, py::arg("other")) // Python doesn't have unsigned integers -- all integers are signed -- // so we'll never see anything that can usefully be routed to the uint64_t // overloads of these methods. - .def("is_max", (bool (Type::*)(int64_t) const) & Type::is_max, py::arg("value")) - .def("is_min", (bool (Type::*)(int64_t) const) & Type::is_min, py::arg("value")) + .def("is_max", (bool(Type::*)(int64_t) const) & Type::is_max, py::arg("value")) + .def("is_min", (bool(Type::*)(int64_t) const) & Type::is_min, py::arg("value")) .def("max", &Type::max) .def("min", &Type::min) .def("__repr__", &type_repr) diff --git a/python_bindings/src/PyType.h b/python_bindings/src/halide/halide_/PyType.h similarity index 100% rename from python_bindings/src/PyType.h rename to python_bindings/src/halide/halide_/PyType.h diff --git a/python_bindings/src/PyVar.cpp b/python_bindings/src/halide/halide_/PyVar.cpp similarity index 84% rename from python_bindings/src/PyVar.cpp rename to python_bindings/src/halide/halide_/PyVar.cpp index 25b2981df470..88d08a64e500 100644 --- a/python_bindings/src/PyVar.cpp +++ b/python_bindings/src/halide/halide_/PyVar.cpp @@ -22,9 +22,9 @@ void define_var(py::module &m) { .def(py::init()) .def("name", &Var::name) .def("same_as", &Var::same_as) - .def("is_implicit", (bool (Var::*)() const) & Var::is_implicit) - .def("implicit_index", (int (Var::*)() const) & Var::implicit_index) - .def("is_placeholder", (bool (Var::*)() const) & Var::is_placeholder) + .def("is_implicit", (bool(Var::*)() const) & Var::is_implicit) + .def("implicit_index", (int(Var::*)() const) & Var::implicit_index) + .def("is_placeholder", (bool(Var::*)() const) & Var::is_placeholder) .def_static("implicit", (Var(*)(int)) & Var::implicit) .def_static("outermost", &Var::outermost) .def("__repr__", &var_repr) diff --git a/python_bindings/src/PyVar.h b/python_bindings/src/halide/halide_/PyVar.h similarity index 100% rename from python_bindings/src/PyVar.h rename to python_bindings/src/halide/halide_/PyVar.h diff --git a/python_bindings/src/PyVarOrRVar.cpp b/python_bindings/src/halide/halide_/PyVarOrRVar.cpp similarity index 100% rename from python_bindings/src/PyVarOrRVar.cpp rename to python_bindings/src/halide/halide_/PyVarOrRVar.cpp diff --git a/python_bindings/src/PyVarOrRVar.h b/python_bindings/src/halide/halide_/PyVarOrRVar.h similarity index 100% rename from python_bindings/src/PyVarOrRVar.h rename to python_bindings/src/halide/halide_/PyVarOrRVar.h diff --git a/python_bindings/stub/AddHalideGeneratorPython.cmake b/python_bindings/stub/AddHalideGeneratorPython.cmake deleted file mode 100644 index 61322dedaf32..000000000000 --- a/python_bindings/stub/AddHalideGeneratorPython.cmake +++ /dev/null @@ -1,14 +0,0 @@ -set(HALIDE_PYSTUB_CPP_PATH ${CMAKE_CURRENT_LIST_DIR}/PyStub.cpp) - -function(add_generator_python TARGET) - set(options) - set(oneValueArgs) - set(multiValueArgs) - cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - - Python3_add_library(${TARGET} MODULE ${HALIDE_PYSTUB_CPP_PATH} ${args_UNPARSED_ARGUMENTS}) - target_compile_definitions(${TARGET} PRIVATE - "HALIDE_PYSTUB_GENERATOR_NAME=${TARGET}" - "HALIDE_PYSTUB_MODULE_NAME=${TARGET}") - target_link_libraries(${TARGET} PRIVATE Halide::PyStubs) -endfunction() diff --git a/python_bindings/stub/CMakeLists.txt b/python_bindings/stub/CMakeLists.txt index fedfda44e3df..a4f69fc044dd 100644 --- a/python_bindings/stub/CMakeLists.txt +++ b/python_bindings/stub/CMakeLists.txt @@ -6,8 +6,7 @@ if (NOT TARGET Halide::PyStubs) target_link_libraries(Halide_PyStubs PUBLIC Halide::Halide) set_target_properties(Halide_PyStubs PROPERTIES EXPORT_NAME PyStubs + CXX_VISIBILITY_PRESET hidden VISIBILITY_INLINES_HIDDEN TRUE POSITION_INDEPENDENT_CODE ON) endif () - -include(${CMAKE_CURRENT_LIST_DIR}/AddHalideGeneratorPython.cmake) diff --git a/python_bindings/stub/PyStub.cpp b/python_bindings/stub/PyStub.cpp index 5064295557b0..251ae5f367a0 100644 --- a/python_bindings/stub/PyStub.cpp +++ b/python_bindings/stub/PyStub.cpp @@ -28,24 +28,16 @@ extern "C" PyObject *_halide_pystub_impl(const char *module_name, const Halide:: #define _HALIDE_CONCAT(first, second) first##second #define HALIDE_CONCAT(first, second) _HALIDE_CONCAT(first, second) -// Don't use HALIDE_EXPORT: Halide.h will already have defined it, -// but it might be defined the wrong way (import rather than export). -#if defined(WIN32) || defined(_WIN32) -#define HALIDE_PLUGIN_EXPORT __declspec(dllexport) -#else -#define HALIDE_PLUGIN_EXPORT __attribute__((visibility("default"))) -#endif - static_assert(PY_MAJOR_VERSION >= 3, "Python bindings for Halide require Python 3+"); -#define _HALIDE_PLUGIN_IMPL(name) extern "C" HALIDE_PLUGIN_EXPORT PyObject *PyInit_##name() /* NOLINT(bugprone-macro-parentheses) */ +#define _HALIDE_PLUGIN_IMPL(name) extern "C" HALIDE_EXPORT_SYMBOL PyObject *PyInit_##name() /* NOLINT(bugprone-macro-parentheses) */ #define HALIDE_PLUGIN_IMPL(name) _HALIDE_PLUGIN_IMPL(name) // clang-format off namespace halide_register_generator { namespace HALIDE_CONCAT(HALIDE_PYSTUB_GENERATOR_NAME, _ns) { - extern std::unique_ptr factory(const Halide::GeneratorContext &context); + extern std::unique_ptr factory(const Halide::GeneratorContext &context); } // namespace HALIDE_CONCAT(HALIDE_PYSTUB_GENERATOR_NAME,_ns) } // namespace halide_register_generator diff --git a/python_bindings/stub/PyStubImpl.cpp b/python_bindings/stub/PyStubImpl.cpp index b5803cbd7fd1..143610d9d2cc 100644 --- a/python_bindings/stub/PyStubImpl.cpp +++ b/python_bindings/stub/PyStubImpl.cpp @@ -24,9 +24,10 @@ namespace py = pybind11; namespace Halide { namespace PythonBindings { +using Parameter = Internal::Parameter; +using ArgInfoKind = Internal::ArgInfoKind; +using ArgInfo = Internal::AbstractGenerator::ArgInfo; using GeneratorFactory = Internal::GeneratorFactory; -using GeneratorParamsMap = Internal::GeneratorParamsMap; -using Stub = Internal::GeneratorStub; using StubInput = Internal::StubInput; using StubInputBuffer = Internal::StubInputBuffer; @@ -40,12 +41,14 @@ void halide_python_error(JITUserContext *, const char *msg) { } void halide_python_print(JITUserContext *, const char *msg) { + py::gil_scoped_acquire acquire; py::print(msg, py::arg("end") = ""); } class HalidePythonCompileTimeErrorReporter : public CompileTimeErrorReporter { public: void warning(const char *msg) override { + py::gil_scoped_acquire acquire; py::print(msg, py::arg("end") = ""); } @@ -63,6 +66,21 @@ void install_error_handlers(py::module &m) { handlers.custom_error = halide_python_error; handlers.custom_print = halide_python_print; Halide::Internal::JITSharedRuntime::set_default_handlers(handlers); + + static py::object halide_error = py::module_::import("halide").attr("HalideError"); + if (halide_error.is(py::none())) { + throw std::runtime_error("Could not find halide.HalideError"); + } + + py::register_exception_translator([](std::exception_ptr p) { // NOLINT + try { + if (p) { + std::rethrow_exception(p); + } + } catch (const Error &e) { + PyErr_SetString(halide_error.ptr(), e.what()); + } + }); } // Anything that defines __getitem__ looks sequencelike to pybind, @@ -71,127 +89,184 @@ bool is_real_sequence(const py::object &o) { return py::isinstance(o) && py::hasattr(o, "__len__"); } -StubInput to_stub_input(const py::object &o) { - // Don't use isinstance: we want to get things that - // can be implicitly converted as well (eg ImageParam -> Func) - try { - return StubInput(StubInputBuffer(o.cast>())); - } catch (...) { - // Not convertible to Buffer. Fall thru and try next. +template +struct cast_error_string { + std::string operator()(const py::handle &h, const std::string &name) { + return "Unable to cast Input " + name + " to " + py::type_id() + " from " + (std::string)py::str(py::type::handle_of(h)); } +}; + +template<> +std::string cast_error_string>::operator()(const py::handle &h, const std::string &name) { + std::ostringstream o; + o << "Input " << name << " requires an ImageParam or Buffer argument when using generate(), but saw " << (std::string)py::str(py::type::handle_of(h)); + return o.str(); +} + +template<> +std::string cast_error_string::operator()(const py::handle &h, const std::string &name) { + std::ostringstream o; + o << "Input " << name << " requires a Func argument when using generate(), but saw " << (std::string)py::str(py::type::handle_of(h)); + return o.str(); +} + +template<> +std::string cast_error_string::operator()(const py::handle &h, const std::string &name) { + std::ostringstream o; + o << "Input " << name << " requires a Param (or scalar literal) argument when using generate(), but saw " << (std::string)py::str(py::type::handle_of(h)); + return o.str(); +} +template +T cast_to(const py::handle &h, const std::string &name) { + // We want to ensure that the error thrown is one that will be translated + // to `hl.HalideError` in Python. try { - return StubInput(o.cast()); - } catch (...) { - // Not convertible to Func. Fall thru and try next. + return h.cast(); + } catch (const std::exception &e) { + throw Halide::Error(cast_error_string()(h, name)); } +} - return StubInput(o.cast()); +template<> +Parameter cast_to(const py::handle &h, const std::string &name) { + auto b = cast_to>(h, name); + Parameter p(b.type(), true, b.dimensions()); + p.set_buffer(b); + return p; } -std::vector to_stub_inputs(const py::object &value) { +template +std::vector to_input_vector(const py::object &value, const std::string &name) { + std::vector v; if (is_real_sequence(value)) { - std::vector v; for (const auto &o : py::reinterpret_borrow(value)) { - v.push_back(to_stub_input(o)); + v.push_back(cast_to(o, name)); } - return v; } else { - return {to_stub_input(value)}; + v.push_back(cast_to(value, name)); } + return v; } -py::object generate_impl(const GeneratorFactory &factory, const GeneratorContext &context, const py::args &args, const py::kwargs &kwargs) { - Stub stub(context, [factory](const GeneratorContext &context) -> std::unique_ptr { - return factory(context); - }); - auto names = stub.get_names(); - _halide_user_assert(!names.outputs.empty()) - << "Generators that use build() (instead of generate()+Output<>) are not supported in the Python bindings."; +py::object generate_impl(const GeneratorFactory &factory, + const GeneratorContext &context, + const py::args &args, + const py::kwargs &kwargs) { + auto generator = factory(context); + + const auto arg_infos = generator->arginfos(); + std::vector input_arguments, output_arguments; + std::map input_arguments_map; + std::set inputs_seen; + for (const auto &a : arg_infos) { + if (a.dir == Internal::ArgInfoDirection::Input) { + input_arguments.push_back(a); + input_arguments_map[a.name] = a; + } else { + output_arguments.push_back(a); + } + } + size_t kw_inputs_specified = 0; + + // GeneratorParams are always specified as an optional named parameter + // called "generator_params", which is expected to be a python dict. + // If generatorparams are specified, do them first, before any Inputs. + if (kwargs.contains("generator_params")) { + py::dict gp = py::cast(kwargs["generator_params"]); + for (auto item : gp) { + const std::string gp_name = py::str(item.first).cast(); + const py::handle gp_value = item.second; + if (py::isinstance(gp_value)) { + generator->set_generatorparam_value(gp_name, gp_value.cast()); + } else if (py::isinstance(gp_value)) { + // Convert [hl.UInt(8), hl.Int(16)] -> uint8,int16 + std::string v; + for (auto t : gp_value) { + if (!v.empty()) { + v += ","; + } + v += py::str(t).cast(); + } + generator->set_generatorparam_value(gp_name, v); + } else { + generator->set_generatorparam_value(gp_name, py::str(gp_value).cast()); + } + } + } // Inputs can be specified by either positional or named args, // but may not be mixed. (i.e., if any inputs are specified as a named - // arg, they all must be specified that way; otherwise they must all be + // argument, they all must be specified that way; otherwise they must all be // positional, in the order declared in the Generator.) - // - // GeneratorParams can only be specified by name, and are always optional. - std::map> kw_inputs; - for (const auto &name : names.inputs) { - _halide_user_assert(kw_inputs.count(name) == 0); // internal error - kw_inputs[name] = std::vector{}; - } - size_t kw_inputs_specified = 0; - - GeneratorParamsMap generator_params; + const auto bind_one = [&generator](py::handle h, const ArgInfo &a) { + py::object o = py::cast(h); + if (a.kind == ArgInfoKind::Buffer) { + generator->bind_input(a.name, to_input_vector(o, a.name)); + } else if (a.kind == ArgInfoKind::Function) { + generator->bind_input(a.name, to_input_vector(o, a.name)); + } else { + generator->bind_input(a.name, to_input_vector(o, a.name)); + } + }; - // Process the kwargs first. for (auto kw : kwargs) { - // If the kwarg is the name of a known input, stick it in the input - // vector. If not, stick it in the GeneratorParamsMap (if it's invalid, - // an error will be reported further downstream). - std::string name = kw.first.cast(); - py::handle value = kw.second; - auto it = kw_inputs.find(name); - if (it != kw_inputs.end()) { - _halide_user_assert(it->second.empty()) - << "Generator Input named '" << it->first << "' was specified more than once."; - it->second = to_stub_inputs(py::cast(value)); - kw_inputs_specified++; - } else { - if (py::isinstance(value)) { - generator_params[name] = value.cast(); - } else { - generator_params[name] = py::str(value).cast(); - } + const std::string name = kw.first.cast(); + const py::handle value = kw.second; + + if (name == "generator_params") { + continue; } - } - std::vector> inputs; - inputs.reserve(names.inputs.size()); + auto it = input_arguments_map.find(name); + _halide_user_assert(it != input_arguments_map.end()) << "Unknown input '" << name << "' specified via keyword argument."; + _halide_user_assert(inputs_seen.count(name) == 0) << "Input " << name << " specified multiple times."; + inputs_seen.insert(name); + + const auto &a = it->second; + bind_one(value, a); + kw_inputs_specified++; + } if (args.empty()) { // No arguments specified positionally, so they must all be via keywords. - _halide_user_assert(kw_inputs_specified == names.inputs.size()) - << "Expected exactly " << names.inputs.size() << " keyword args for inputs, but saw " << kw_inputs_specified << "."; - for (const auto &name : names.inputs) { - inputs.push_back(std::move(kw_inputs[name])); - } + _halide_user_assert(kw_inputs_specified == input_arguments.size()) + << "Expected exactly " << input_arguments.size() << " keyword args for inputs, but saw " << kw_inputs_specified << "."; } else { // Some positional arguments, so all inputs must be positional (and none via keyword). - _halide_user_assert(kw_inputs_specified == 0) - << "Cannot use both positional and keyword arguments for inputs."; - _halide_user_assert(args.size() == names.inputs.size()) - << "Expected exactly " << names.inputs.size() << " positional args for inputs, but saw " << args.size() << "."; - for (auto arg : args) { - inputs.push_back(to_stub_inputs(py::cast(arg))); + _halide_user_assert(kw_inputs_specified == 0) << "Cannot use both positional and keyword arguments for inputs."; + _halide_user_assert(args.size() == input_arguments.size()) + << "Expected exactly " << input_arguments.size() << " positional args for inputs, but saw " << args.size() << "."; + for (size_t i = 0; i < args.size(); i++) { + const auto &a = input_arguments[i]; + _halide_user_assert(inputs_seen.count(a.name) == 0) << "Input " << a.name << " specified multiple times."; + inputs_seen.insert(a.name); + bind_one(args[i], a); } } - // Verify everything is there - _halide_user_assert(inputs.size() == names.inputs.size()); - for (size_t i = 0; i < inputs.size(); ++i) { - _halide_user_assert(!inputs[i].empty()) - << "Generator Input named '" << names.inputs[i] << "' was not specified."; - } + generator->build_pipeline(); - const std::vector> outputs = stub.generate(generator_params, inputs); + const size_t outputs_size = output_arguments.size(); + py::tuple py_outputs(outputs_size); + for (size_t i = 0; i < outputs_size; i++) { + std::vector outputs = generator->output_func(output_arguments[i].name); - py::tuple py_outputs(outputs.size()); - for (size_t i = 0; i < outputs.size(); i++) { py::object o; - if (outputs[i].size() == 1) { + if (outputs.size() == 1) { // convert list-of-1 into single element - o = py::cast(outputs[i][0]); + o = py::cast(outputs[0]); } else { - o = py::cast(outputs[i]); + o = py::cast(outputs); } - if (outputs.size() == 1) { - // bail early, return the single object rather than a dict + if (outputs_size == 1) { + // bail early, returning the single object rather than a dict return o; } py_outputs[i] = o; } + // An explicit "std::move" is needed here because there's // an implicit tuple->object conversion that inhibits it otherwise. return std::move(py_outputs); diff --git a/python_bindings/stub/ext.ldscript.apple.in b/python_bindings/stub/ext.ldscript.apple.in new file mode 100644 index 000000000000..90695315ccd6 --- /dev/null +++ b/python_bindings/stub/ext.ldscript.apple.in @@ -0,0 +1 @@ +_PyInit_${SYMBOL} diff --git a/python_bindings/stub/ext.ldscript.linux.in b/python_bindings/stub/ext.ldscript.linux.in new file mode 100644 index 000000000000..b426c20b08b8 --- /dev/null +++ b/python_bindings/stub/ext.ldscript.linux.in @@ -0,0 +1,4 @@ +{ +global: PyInit_${SYMBOL}; +local: *; +}; diff --git a/python_bindings/test/CMakeLists.txt b/python_bindings/test/CMakeLists.txt new file mode 100644 index 000000000000..e0f9c76f5470 --- /dev/null +++ b/python_bindings/test/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(apps) +add_subdirectory(correctness) +add_subdirectory(generators) diff --git a/python_bindings/test/apps/CMakeLists.txt b/python_bindings/test/apps/CMakeLists.txt new file mode 100644 index 000000000000..70daeb483fd6 --- /dev/null +++ b/python_bindings/test/apps/CMakeLists.txt @@ -0,0 +1,16 @@ +set(tests + bilateral_grid.py + blur.py + erode.py + interpolate.py + local_laplacian.py) + +foreach (test IN LISTS tests) + add_python_test( + FILE "${test}" + LABEL python_apps + ENVIRONMENT + "TEST_TMPDIR=$" + "TEST_IMAGES_DIR=$" + ) +endforeach () diff --git a/python_bindings/apps/bilateral_grid.py b/python_bindings/test/apps/bilateral_grid.py similarity index 89% rename from python_bindings/apps/bilateral_grid.py rename to python_bindings/test/apps/bilateral_grid.py index 24a4b038361a..3c0e0376e4e3 100644 --- a/python_bindings/apps/bilateral_grid.py +++ b/python_bindings/test/apps/bilateral_grid.py @@ -9,6 +9,18 @@ import imageio import os.path +# Return the directory to look in for test images: +# - If TEST_IMAGES_DIR is defined, use that +# - Otherwise, create a relative path to the C++ apps/images dir +def apps_images_dir(): + return os.environ.get("TEST_IMAGES_DIR", os.path.join(os.path.dirname(__file__), "../../apps/images")) + +# Return the directory to use when writing output files: +# - If TEST_TMPDIR is defined, use that +# - Otherwise, return an empty string (i.e., relative to whatever the current directory is) +def apps_output_dir(): + return os.environ.get("TEST_TMPDIR", "") + def get_bilateral_grid(input, r_sigma, s_sigma, aot=False): x = hl.Var('x') y = hl.Var('y') @@ -102,7 +114,7 @@ def generate_compiled_file(bilateral_grid): def get_input_data(): - image_path = os.path.join(os.path.dirname(__file__), "../../apps/images/rgb.png") + image_path = os.path.join(apps_images_dir(), "rgb.png") assert os.path.exists(image_path), \ "Could not find %s" % image_path rgb_data = imageio.imread(image_path) @@ -133,8 +145,8 @@ def filter_test_image(bilateral_grid, input): output_image.copy_to_host() # save results - input_path = "bilateral_grid_input.png" - output_path = "bilateral_grid.png" + input_path = os.path.join(apps_output_dir(), "bilateral_grid_input.png") + output_path = os.path.join(apps_output_dir(), "bilateral_grid.png") imageio.imsave(input_path, input_data) imageio.imsave(output_path, output_data) print("\nbilateral_grid realized on output_image.") diff --git a/python_bindings/apps/blur.py b/python_bindings/test/apps/blur.py similarity index 73% rename from python_bindings/apps/blur.py rename to python_bindings/test/apps/blur.py index 73396ab61430..d89383166763 100644 --- a/python_bindings/apps/blur.py +++ b/python_bindings/test/apps/blur.py @@ -4,6 +4,18 @@ import imageio import os.path +# Return the directory to look in for test images: +# - If TEST_IMAGES_DIR is defined, use that +# - Otherwise, create a relative path to the C++ apps/images dir +def apps_images_dir(): + return os.environ.get("TEST_IMAGES_DIR", os.path.join(os.path.dirname(__file__), "../../apps/images")) + +# Return the directory to use when writing output files: +# - If TEST_TMPDIR is defined, use that +# - Otherwise, return an empty string (i.e., relative to whatever the current directory is) +def apps_output_dir(): + return os.environ.get("TEST_TMPDIR", "") + def get_blur(input): assert type(input) == hl.ImageParam assert input.dimensions() == 2 @@ -31,7 +43,7 @@ def get_blur(input): def get_input_data(): - image_path = os.path.join(os.path.dirname(__file__), "../../apps/images/rgb.png") + image_path = os.path.join(apps_images_dir(), "rgb.png") assert os.path.exists(image_path), \ "Could not find %s" % image_path rgb_data = imageio.imread(image_path) @@ -63,8 +75,8 @@ def main(): blur.realize(output_image) # save results - input_path = "blur_input.png" - output_path = "blur_result.png" + input_path = os.path.join(apps_output_dir(), "blur_input.png") + output_path = os.path.join(apps_output_dir(), "blur_result.png") imageio.imsave(input_path, input_data) imageio.imsave(output_path, output_data) print("\nblur realized on output image.", diff --git a/python_bindings/apps/erode.py b/python_bindings/test/apps/erode.py similarity index 75% rename from python_bindings/apps/erode.py rename to python_bindings/test/apps/erode.py index 5a870a45b1f2..95cbf9e6d39f 100644 --- a/python_bindings/apps/erode.py +++ b/python_bindings/test/apps/erode.py @@ -8,6 +8,18 @@ import imageio import os.path +# Return the directory to look in for test images: +# - If TEST_IMAGES_DIR is defined, use that +# - Otherwise, create a relative path to the C++ apps/images dir +def apps_images_dir(): + return os.environ.get("TEST_IMAGES_DIR", os.path.join(os.path.dirname(__file__), "../../apps/images")) + +# Return the directory to use when writing output files: +# - If TEST_TMPDIR is defined, use that +# - Otherwise, return an empty string (i.e., relative to whatever the current directory is) +def apps_output_dir(): + return os.environ.get("TEST_TMPDIR", "") + def get_erode(input): """ Erode on 5x5 stencil, first erode x then erode y. @@ -35,8 +47,7 @@ def get_erode(input): def get_input_data(): - - image_path = os.path.join(os.path.dirname(__file__), "../../apps/images/rgb.png") + image_path = os.path.join(apps_images_dir(), "rgb.png") assert os.path.exists(image_path), \ "Could not find %s" % image_path rgb_data = imageio.imread(image_path) @@ -69,8 +80,8 @@ def main(): erode.realize(output_image) # save results - input_path = "erode_input.png" - output_path = "erode_result.png" + input_path = os.path.join(apps_output_dir(), "erode_input.png") + output_path = os.path.join(apps_output_dir(), "erode_result.png") imageio.imsave(input_path, input_data) imageio.imsave(output_path, output_data) print("\nerode realized on output image.", diff --git a/python_bindings/apps/interpolate.py b/python_bindings/test/apps/interpolate.py similarity index 90% rename from python_bindings/apps/interpolate.py rename to python_bindings/test/apps/interpolate.py index 6db730a58973..d6569d107368 100644 --- a/python_bindings/apps/interpolate.py +++ b/python_bindings/test/apps/interpolate.py @@ -9,6 +9,18 @@ import numpy as np import os.path +# Return the directory to look in for test images: +# - If TEST_IMAGES_DIR is defined, use that +# - Otherwise, create a relative path to the C++ apps/images dir +def apps_images_dir(): + return os.environ.get("TEST_IMAGES_DIR", os.path.join(os.path.dirname(__file__), "../../apps/images")) + +# Return the directory to use when writing output files: +# - If TEST_TMPDIR is defined, use that +# - Otherwise, return an empty string (i.e., relative to whatever the current directory is) +def apps_output_dir(): + return os.environ.get("TEST_TMPDIR", "") + int_t = hl.Int(32) float_t = hl.Float(32) @@ -149,7 +161,7 @@ def get_interpolate(input, levels): def get_input_data(): - image_path = os.path.join(os.path.dirname(__file__), "../../apps/images/rgba.png") + image_path = os.path.join(apps_images_dir(), "rgba.png") assert os.path.exists(image_path), "Could not find %s" % image_path rgba_data = imageio.imread(image_path) @@ -187,8 +199,8 @@ def main(): output_data = (output_data * 255).astype(np.uint8) # save results - input_path = "interpolate_input.png" - output_path = "interpolate_result.png" + input_path = os.path.join(apps_output_dir(), "interpolate_input.png") + output_path = os.path.join(apps_output_dir(), "interpolate_result.png") imageio.imsave(input_path, input_data) imageio.imsave(output_path, output_data) diff --git a/python_bindings/apps/local_laplacian.py b/python_bindings/test/apps/local_laplacian.py similarity index 91% rename from python_bindings/apps/local_laplacian.py rename to python_bindings/test/apps/local_laplacian.py index c0b597d05c78..99d9b4e38061 100644 --- a/python_bindings/apps/local_laplacian.py +++ b/python_bindings/test/apps/local_laplacian.py @@ -8,6 +8,18 @@ import imageio import os.path +# Return the directory to look in for test images: +# - If TEST_IMAGES_DIR is defined, use that +# - Otherwise, create a relative path to the C++ apps/images dir +def apps_images_dir(): + return os.environ.get("TEST_IMAGES_DIR", os.path.join(os.path.dirname(__file__), "../../apps/images")) + +# Return the directory to use when writing output files: +# - If TEST_TMPDIR is defined, use that +# - Otherwise, return an empty string (i.e., relative to whatever the current directory is) +def apps_output_dir(): + return os.environ.get("TEST_TMPDIR", "") + int_t = hl.Int(32) float_t = hl.Float(32) @@ -175,7 +187,7 @@ def upsample2D(f): def get_input_data(): - image_path = os.path.join(os.path.dirname(__file__), "../../apps/images/rgb.png") + image_path = os.path.join(apps_images_dir(), "rgb.png") assert os.path.exists(image_path), "Could not find {}".format(image_path) rgb_data = imageio.imread(image_path) @@ -205,8 +217,8 @@ def filter_test_image(local_laplacian, input): output_data = (output_data >> 8).astype(np.uint8) # save results - input_path = "local_laplacian_input.png" - output_path = "local_laplacian.png" + input_path = os.path.join(apps_output_dir(), "local_laplacian_input.png") + output_path = os.path.join(apps_output_dir(), "local_laplacian.png") imageio.imsave(input_path, input_data) imageio.imsave(output_path, output_data) diff --git a/python_bindings/test/correctness/CMakeLists.txt b/python_bindings/test/correctness/CMakeLists.txt new file mode 100644 index 000000000000..7180fa6d0732 --- /dev/null +++ b/python_bindings/test/correctness/CMakeLists.txt @@ -0,0 +1,35 @@ +add_library(the_sort_function MODULE the_sort_function.c) +target_link_libraries(the_sort_function PRIVATE Halide::Runtime) + +set(tests + addconstant_test.py + atomics.py + autodiff.py + basics.py + bit_test.py + boundary_conditions.py + buffer.py + callable.py + compile_to.py + division.py + extern.py + float_precision_test.py + iroperator.py + multipass_constraints.py + pystub.py + rdom.py + realize_warnings.py + target.py + tuple_select.py + type.py + user_context_test.py + var.py + ) + +foreach (test IN LISTS tests) + add_python_test( + FILE "${test}" + LABEL python_correctness + PYTHONPATH "$" "$" + ) +endforeach () diff --git a/python_bindings/correctness/addconstant_test.py b/python_bindings/test/correctness/addconstant_test.py similarity index 100% rename from python_bindings/correctness/addconstant_test.py rename to python_bindings/test/correctness/addconstant_test.py diff --git a/python_bindings/correctness/atomics.py b/python_bindings/test/correctness/atomics.py similarity index 100% rename from python_bindings/correctness/atomics.py rename to python_bindings/test/correctness/atomics.py diff --git a/python_bindings/correctness/autodiff.py b/python_bindings/test/correctness/autodiff.py similarity index 100% rename from python_bindings/correctness/autodiff.py rename to python_bindings/test/correctness/autodiff.py diff --git a/python_bindings/correctness/basics.py b/python_bindings/test/correctness/basics.py similarity index 79% rename from python_bindings/correctness/basics.py rename to python_bindings/test/correctness/basics.py index b63bf2ff299c..75581378e6d3 100644 --- a/python_bindings/correctness/basics.py +++ b/python_bindings/test/correctness/basics.py @@ -11,7 +11,7 @@ def test_compiletime_error(): buf = hl.Buffer(hl.UInt(8), [2, 2]) try: f.realize(buf) - except RuntimeError as e: + except hl.HalideError as e: assert 'Output buffer f has type uint16 but type of the buffer passed in is uint8' in str(e) else: assert False, 'Did not see expected exception!' @@ -25,7 +25,7 @@ def test_runtime_error(): buf = hl.Buffer(hl.UInt(8), [10]) try: f.realize(buf) - except RuntimeError as e: + except hl.HalideError as e: assert 'do not cover required region' in str(e) else: assert False, 'Did not see expected exception!' @@ -117,7 +117,7 @@ def test_basics2(): try: val1 = clamped[x * s_sigma - s_sigma/2, y * s_sigma - s_sigma/2] - except RuntimeError as e: + except hl.HalideError as e: assert 'Implicit cast from float32 to int' in str(e) else: assert False, 'Did not see expected exception!' @@ -309,11 +309,88 @@ def test_bool_conversion(): # Verify that this doesn't fail with 'Argument passed to specialize must be of type bool' f.compute_root().specialize(True) +def test_typed_funcs(): + x = hl.Var('x') + y = hl.Var('y') + + f = hl.Func('f') + assert not f.defined() + try: + assert f.type() == Int(32) + except hl.HalideError as e: + assert 'it is undefined' in str(e) + else: + assert False, 'Did not see expected exception!' + + try: + assert f.outputs() == 0 + except hl.HalideError as e: + assert 'it is undefined' in str(e) + else: + assert False, 'Did not see expected exception!' + + try: + assert f.dimensions() == 0 + except hl.HalideError as e: + assert 'it is undefined' in str(e) + else: + assert False, 'Did not see expected exception!' + + + f = hl.Func(hl.Int(32), 2, 'f') + assert not f.defined() + assert f.type() == hl.Int(32) + assert f.types() == [hl.Int(32)] + assert f.outputs() == 1 + assert f.dimensions() == 2 + + f = hl.Func([hl.Int(32), hl.Float(64)], 3, 'f') + assert not f.defined() + try: + assert f.type() == hl.Int(32) + except hl.HalideError as e: + assert 'it returns a Tuple' in str(e) + else: + assert False, 'Did not see expected exception!' + + assert f.types() == [hl.Int(32), hl.Float(64)] + assert f.outputs() == 2 + assert f.dimensions() == 3 + + f = hl.Func(hl.Int(32), 1, 'f') + try: + f[x, y] = hl.i32(0); + f.realize([10, 10]) + except hl.HalideError as e: + assert 'is constrained to have exactly 1 dimensions, but is defined with 2 dimensions' in str(e) + else: + assert False, 'Did not see expected exception!' + + f = hl.Func(hl.Int(32), 2, 'f') + try: + f[x, y] = hl.i16(0); + f.realize([10, 10]) + except hl.HalideError as e: + assert 'is constrained to only hold values of type int32 but is defined with values of type int16' in str(e) + else: + assert False, 'Did not see expected exception!' + + f = hl.Func((hl.Int(32), hl.Float(32)), 2, 'f') + try: + f[x, y] = (hl.i16(0), hl.f64(0)) + f.realize([10, 10]) + except hl.HalideError as e: + assert 'is constrained to only hold values of type (int32, float32) but is defined with values of type (int16, float64)' in str(e) + else: + assert False, 'Did not see expected exception!' + + if __name__ == "__main__": test_compiletime_error() test_runtime_error() test_misused_and() test_misused_or() + test_typed_funcs() test_float_or_int() test_operator_order() test_int_promotion() diff --git a/python_bindings/correctness/bit_test.py b/python_bindings/test/correctness/bit_test.py similarity index 100% rename from python_bindings/correctness/bit_test.py rename to python_bindings/test/correctness/bit_test.py diff --git a/python_bindings/correctness/boundary_conditions.py b/python_bindings/test/correctness/boundary_conditions.py similarity index 100% rename from python_bindings/correctness/boundary_conditions.py rename to python_bindings/test/correctness/boundary_conditions.py diff --git a/python_bindings/correctness/buffer.py b/python_bindings/test/correctness/buffer.py similarity index 91% rename from python_bindings/correctness/buffer.py rename to python_bindings/test/correctness/buffer.py index 11ac31e49729..10196300eb6e 100644 --- a/python_bindings/correctness/buffer.py +++ b/python_bindings/test/correctness/buffer.py @@ -130,6 +130,20 @@ def test_float16(): hl_img = hl.Buffer(array_in) array_out = np.array(hl_img, copy = False) +# TODO: https://github.com/halide/Halide/issues/6849 +# def test_bfloat16(): +# try: +# from tensorflow.python.lib.core import _pywrap_bfloat16 +# bfloat16 = _pywrap_bfloat16.TF_bfloat16_type() +# array_in = np.zeros((256, 256, 3), dtype=bfloat16, order='F') +# hl_img = hl.Buffer(array_in) +# array_out = np.array(hl_img, copy = False) +# except ModuleNotFoundError as e: +# print("skipping test_bfloat16() because tensorflow was not found: %s" % str(e)) +# return +# else: +# assert False, "This should not happen" + def test_int64(): array_in = np.zeros((256, 256, 3), dtype=np.int64, order='F') hl_img = hl.Buffer(array_in) @@ -250,7 +264,7 @@ def test_overflow(): try: hl.Buffer(size_over_intmax) except ValueError as e: - assert 'Out of range arguments to make_dim_vec.' in str(e) + assert 'Out of range dimensions in buffer conversion' in str(e) def test_buffer_to_str(): b = hl.Buffer() @@ -279,6 +293,8 @@ def test_scalar_buffers(): test_for_each_element() test_fill_all_equal() test_bufferinfo_sharing() + # TODO: https://github.com/halide/Halide/issues/6849 + # test_bfloat16() test_float16() test_int64() test_reorder() diff --git a/python_bindings/test/correctness/callable.py b/python_bindings/test/correctness/callable.py new file mode 100644 index 000000000000..93016743ebdc --- /dev/null +++ b/python_bindings/test/correctness/callable.py @@ -0,0 +1,177 @@ +import halide as hl +import numpy as np + +import simple_pystub # Needed for create_callable_from_generator("simple") to work + +def test_callable(): + p_int16 = hl.Param(hl.Int(16), 42) + p_float = hl.Param(hl.Float(32), 1.0) + p_img = hl.ImageParam(hl.UInt(8), 2) + + x = hl.Var('x') + y = hl.Var('y') + f = hl.Func('f') + + f[x, y] = p_img[x, y] + hl.u8(p_int16 / p_float) + + in1 = hl.Buffer(hl.UInt(8), [10, 10]) + in2 = hl.Buffer(hl.UInt(8), [10, 10]) + + for i in range(10): + for j in range(10): + in1[i, j] = i + j * 10 + in2[i, j] = i * 10 + j + + c = f.compile_to_callable([p_img, p_int16, p_float]); + + out1 = hl.Buffer(hl.UInt(8), [10, 10]) + c(in1, 42, 1.0, out1) + + out2 = hl.Buffer(hl.UInt(8), [10, 10]) + c(in2, 22, 2.0, out2) + + out3 = hl.Buffer(hl.UInt(8), [10, 10]) + c(in1, 12, 1.0, out3) + + out4 = hl.Buffer(hl.UInt(8), [10, 10]) + c(in2, 16, 1.0, out4) + + for i in range(10): + for j in range(10): + assert out1[i, j] == i + j * 10 + 42 + assert out2[i, j] == i * 10 + j + 11 + assert out3[i, j] == i + j * 10 + 12 + assert out4[i, j] == i * 10 + j + 16 + + # Test bounds inference. Note that in Python there + # isn't a "natural" way to create a buffer with a null host ptr + # so we use this specific API for the purpose. + in_bounds = hl.Buffer.make_bounds_query(hl.UInt(8), [1, 1]) + out_bounds = hl.Buffer.make_bounds_query(hl.UInt(8), [20, 20]) + c(in_bounds, 42, 1.0, out_bounds) + + assert in_bounds.defined() + assert in_bounds.dim(0).extent() == 20 + assert in_bounds.dim(1).extent() == 20 + assert in1.dim(0).extent() == 10 + assert in1.dim(1).extent() == 10 + +def test_simple(): + x, y = hl.Var(), hl.Var() + target = hl.get_jit_target_from_environment() + + b_in = hl.Buffer(hl.UInt(8), [2, 2]) + b_in.fill(123) + + # All inputs to a Callable must be fully realized, so any Func inputs + # that the Generator has implicitly become Buffer inputs of the same type + # and dimensionality. + f_in = hl.Buffer(hl.Int(32), [2, 2]) + for xx in range(2): + for yy in range(2): + f_in[xx, yy] = xx + yy + + float_in = 3.5 + + b_out = hl.Buffer(hl.Float(32), [2, 2]) + + def _check(offset = 0): + assert b_out[0, 0] == float_in + 0 + offset + 123 + assert b_out[0, 1] == float_in + 1 + offset + 123 + assert b_out[1, 0] == float_in + 1 + offset + 123 + assert b_out[1, 1] == float_in + 2 + offset + 123 + + gp = {"func_input.type": "int32"} + simple = hl.create_callable_from_generator(target, "simple", gp) + + # ----------- Positional arguments + simple(b_in, f_in, float_in, b_out) + _check() + + # ----------- Keyword arguments + # Natural order + simple(buffer_input=b_in, func_input=f_in, float_arg=float_in, simple_output=b_out) + _check() + + # Weird order + simple(float_arg=float_in, simple_output=b_out, buffer_input=b_in, func_input=f_in) + _check() + + # ----------- Positional + Keywords + + # Natural order + simple(b_in, func_input=f_in, simple_output=b_out, float_arg=float_in) + _check() + + # Weird order + simple(b_in, f_in, float_in, simple_output=b_out) + _check() + + # ----------- Above set again, w/ additional GeneratorParam mixed in + k = 42 + + gp = {"func_input.type": "int32", "offset": str(k)} + simple_42 = hl.create_callable_from_generator(target, "simple", gp) + simple_42(b_in, f_in, float_in, b_out) + _check(k) + + # ----------- Test various failure modes + try: + # too many positional args + simple(b_in, f_in, float_in, 4, b_out) + except hl.HalideError as e: + assert 'Expected at most 4 positional arguments, but saw 5.' in str(e) + else: + assert False, 'Did not see expected exception!' + + try: + # too few positional args + simple(b_in, f_in) + except hl.HalideError as e: + assert 'Expected exactly 4 positional arguments, but saw 2.' in str(e) + else: + assert False, 'Did not see expected exception!' + + try: + # Inputs that can't be converted to what the receiver needs (positional) + simple(hl.f32(3.141592), "happy", k, b_out) + except hl.HalideError as e: + assert 'is not an instance of' in str(e) + else: + assert False, 'Did not see expected exception!' + + try: + # Inputs that can't be converted to what the receiver needs (named) + simple(b_in, f_in, float_in, simple_output="bogus") + except hl.HalideError as e: + assert 'is not an instance of' in str(e) + else: + assert False, 'Did not see expected exception!' + + try: + # Bad keyword argument + simple(buffer_input=b_in, float_arg=float_in, simple_output=b_out, funk_input=f_in) + except hl.HalideError as e: + assert "Unknown argument 'funk_input' specified via keyword." in str(e) + else: + assert False, 'Did not see expected exception!' + + try: + # too few keyword args + simple(float_arg=float_in, simple_output=b_out, func_input=f_in) + except hl.HalideError as e: + assert 'Argument buffer_input was not specified by either positional or keyword argument.' in str(e) + else: + assert False, 'Did not see expected exception!' + + try: + # Arg specified by pos + kw + simple(b_in, buffer_input=b_in, func_input=f_in, float_arg=float_in, simple_output=b_out) + except hl.HalideError as e: + assert 'Argument buffer_input specified multiple times.' in str(e) + else: + assert False, 'Did not see expected exception!' + +if __name__ == "__main__": + test_callable() + test_simple() diff --git a/python_bindings/correctness/compile_to.py b/python_bindings/test/correctness/compile_to.py similarity index 100% rename from python_bindings/correctness/compile_to.py rename to python_bindings/test/correctness/compile_to.py diff --git a/python_bindings/correctness/division.py b/python_bindings/test/correctness/division.py similarity index 100% rename from python_bindings/correctness/division.py rename to python_bindings/test/correctness/division.py diff --git a/python_bindings/correctness/extern.py b/python_bindings/test/correctness/extern.py similarity index 85% rename from python_bindings/correctness/extern.py rename to python_bindings/test/correctness/extern.py index d7836b3691ea..2b8d8c52dad6 100644 --- a/python_bindings/correctness/extern.py +++ b/python_bindings/test/correctness/extern.py @@ -32,10 +32,10 @@ def test_extern(): try: sort_func.compile_jit() - except RuntimeError: - pass + except hl.HalideError: + assert 'cannot be converted to a bool' in str(e) else: - raise Exception("compile_jit should have raised a 'Symbol not found' RuntimeError") + assert False, 'Did not see expected exception!' import ctypes @@ -44,10 +44,10 @@ def test_extern(): try: sort_func.compile_jit() - except RuntimeError: - print("ctypes CDLL did not work out") + except hl.HalideError: + assert 'cannot be converted to a bool' in str(e) else: - print("ctypes CDLL worked !") + assert False, 'Did not see expected exception!' lib_path = "the_sort_function.so" #lib_path = "/home/rodrigob/code/references/" \ diff --git a/python_bindings/correctness/float_precision_test.py b/python_bindings/test/correctness/float_precision_test.py similarity index 100% rename from python_bindings/correctness/float_precision_test.py rename to python_bindings/test/correctness/float_precision_test.py diff --git a/python_bindings/correctness/iroperator.py b/python_bindings/test/correctness/iroperator.py similarity index 100% rename from python_bindings/correctness/iroperator.py rename to python_bindings/test/correctness/iroperator.py diff --git a/python_bindings/correctness/multipass_constraints.py b/python_bindings/test/correctness/multipass_constraints.py similarity index 100% rename from python_bindings/correctness/multipass_constraints.py rename to python_bindings/test/correctness/multipass_constraints.py diff --git a/python_bindings/correctness/negate_test.py b/python_bindings/test/correctness/negate_test.py similarity index 100% rename from python_bindings/correctness/negate_test.py rename to python_bindings/test/correctness/negate_test.py diff --git a/python_bindings/correctness/pystub.py b/python_bindings/test/correctness/pystub.py similarity index 66% rename from python_bindings/correctness/pystub.py rename to python_bindings/test/correctness/pystub.py index 280de4a5359c..67941193a80b 100644 --- a/python_bindings/correctness/pystub.py +++ b/python_bindings/test/correctness/pystub.py @@ -1,9 +1,8 @@ import halide as hl -import simplestub -# test alternate-but-legal syntax -from complexstub import generate as complexstub +import simple_pystub +import complex_pystub def _realize_and_check(f, offset = 0): b = hl.Buffer(hl.Float(32), [2, 2]) @@ -15,7 +14,7 @@ def _realize_and_check(f, offset = 0): assert b[1, 1] == 5.5 + offset + 123 -def test_simplestub(): +def test_simple(gen): x, y = hl.Var(), hl.Var() target = hl.get_jit_target_from_environment() @@ -26,102 +25,120 @@ def test_simplestub(): f_in[x, y] = x + y # ----------- Inputs by-position - f = simplestub.generate(target, b_in, f_in, 3.5) + f = gen(target, b_in, f_in, 3.5) _realize_and_check(f) # ----------- Inputs by-name - f = simplestub.generate(target, buffer_input=b_in, func_input=f_in, float_arg=3.5) + f = gen(target, buffer_input=b_in, func_input=f_in, float_arg=3.5) _realize_and_check(f) - f = simplestub.generate(target, float_arg=3.5, buffer_input=b_in, func_input=f_in) + f = gen(target, float_arg=3.5, buffer_input=b_in, func_input=f_in) _realize_and_check(f) # ----------- Above set again, w/ GeneratorParam mixed in k = 42 + gp = { "offset": k } + # (positional) - f = simplestub.generate(target, b_in, f_in, 3.5, offset=k) + f = gen(target, b_in, f_in, 3.5, generator_params=gp) _realize_and_check(f, k) # (keyword) - f = simplestub.generate(target, offset=k, buffer_input=b_in, func_input=f_in, float_arg=3.5) + f = gen(target, generator_params=gp, buffer_input=b_in, func_input=f_in, float_arg=3.5) _realize_and_check(f, k) - f = simplestub.generate(target, buffer_input=b_in, offset=k, func_input=f_in, float_arg=3.5) + f = gen(target, buffer_input=b_in, generator_params=gp, func_input=f_in, float_arg=3.5) _realize_and_check(f, k) - f = simplestub.generate(target, buffer_input=b_in, func_input=f_in, offset=k, float_arg=3.5) + f = gen(target, buffer_input=b_in, func_input=f_in, generator_params=gp, float_arg=3.5) _realize_and_check(f, k) - f = simplestub.generate(target, buffer_input=b_in, float_arg=3.5, func_input=f_in, offset=k) + f = gen(target, buffer_input=b_in, float_arg=3.5, func_input=f_in, generator_params=gp) _realize_and_check(f, k) # ----------- Test various failure modes try: # Inputs w/ mixed by-position and by-name - f = simplestub.generate(target, b_in, f_in, float_arg=3.5) - except RuntimeError as e: + f = gen(target, b_in, f_in, float_arg=3.5) + except hl.HalideError as e: assert 'Cannot use both positional and keyword arguments for inputs.' in str(e) else: assert False, 'Did not see expected exception!' try: # too many positional args - f = simplestub.generate(target, b_in, f_in, 3.5, 4) - except RuntimeError as e: + f = gen(target, b_in, f_in, 3.5, 4) + except hl.HalideError as e: assert 'Expected exactly 3 positional args for inputs, but saw 4.' in str(e) else: assert False, 'Did not see expected exception!' try: # too few positional args - f = simplestub.generate(target, b_in, f_in) - except RuntimeError as e: + f = gen(target, b_in, f_in) + except hl.HalideError as e: assert 'Expected exactly 3 positional args for inputs, but saw 2.' in str(e) else: assert False, 'Did not see expected exception!' try: # Inputs that can't be converted to what the receiver needs (positional) - f = simplestub.generate(target, hl.f32(3.141592), "happy", k) - except RuntimeError as e: - assert 'Unable to cast Python instance' in str(e) + f = gen(target, hl.f32(3.141592), "happy", k) + except hl.HalideError as e: + assert 'Input buffer_input requires an ImageParam or Buffer argument' in str(e) else: assert False, 'Did not see expected exception!' try: # Inputs that can't be converted to what the receiver needs (named) - f = simplestub.generate(target, b_in, f_in, float_arg="bogus") - except RuntimeError as e: - assert 'Unable to cast Python instance' in str(e) + f = gen(target, b_in, f_in, float_arg="bogus") + except hl.HalideError as e: + assert 'Input float_arg requires a Param (or scalar literal) argument' in str(e) else: assert False, 'Did not see expected exception!' try: # Input specified by both pos and kwarg - f = simplestub.generate(target, b_in, f_in, 3.5, float_arg=4.5) - except RuntimeError as e: + f = gen(target, b_in, f_in, 3.5, float_arg=4.5) + except hl.HalideError as e: assert "Cannot use both positional and keyword arguments for inputs." in str(e) else: assert False, 'Did not see expected exception!' + try: + # generator_params is not a dict + f = gen(target, b_in, f_in, 3.5, generator_params=[1, 2, 3]) + except TypeError as e: + assert "cannot convert dictionary" in str(e) + else: + assert False, 'Did not see expected exception!' + + try: + # Bad gp name + f = gen(target, b_in, f_in, 3.5, generator_params={"foo": 0}) + except hl.HalideError as e: + assert "has no GeneratorParam named: foo" in str(e) + else: + assert False, 'Did not see expected exception!' + try: # Bad input name - f = simplestub.generate(target, buffer_input=b_in, float_arg=3.5, offset=k, funk_input=f_in) - except RuntimeError as e: - assert "Expected exactly 3 keyword args for inputs, but saw 2." in str(e) + f = gen(target, buffer_input=b_in, float_arg=3.5, generator_params=gp, funk_input=f_in) + except hl.HalideError as e: + assert "Unknown input 'funk_input' specified via keyword argument." in str(e) else: assert False, 'Did not see expected exception!' try: # Bad gp name - f = simplestub.generate(target, buffer_input=b_in, float_arg=3.5, offset=k, func_input=f_in, nonexistent_generator_param="wat") - except RuntimeError as e: - assert "Generator simplestub has no GeneratorParam named: nonexistent_generator_param" in str(e) + f = gen(target, buffer_input=b_in, float_arg=3.5, generator_params=gp, func_input=f_in, nonexistent_generator_param="wat") + except hl.HalideError as e: + assert "Unknown input 'nonexistent_generator_param' specified via keyword argument." in str(e) else: assert False, 'Did not see expected exception!' -def test_looplevel(): +def test_looplevel(gen): x, y = hl.Var('x'), hl.Var('y') target = hl.get_jit_target_from_environment() @@ -132,8 +149,10 @@ def test_looplevel(): func_input[x, y] = x + y simple_compute_at = hl.LoopLevel() - simple = simplestub.generate(target, buffer_input, func_input, 3.5, - compute_level=simple_compute_at) + simple = gen(target, buffer_input, func_input, 3.5, + generator_params = { + "compute_level": simple_compute_at + }) computed_output = hl.Func('computed_output') computed_output[x, y] = simple[x, y] + 3 @@ -151,7 +170,7 @@ def _make_constant_image(): constant_image[x, y, c] = x + y + c return constant_image -def test_complexstub(): +def test_complex(gen): constant_image = _make_constant_image() input = hl.ImageParam(hl.UInt(8), 3, 'input') input.set(constant_image) @@ -165,16 +184,18 @@ def test_complexstub(): func_input = hl.Func("func_input") func_input[x, y, c] = hl.u16(x + y + c) - r = complexstub(target, - typed_buffer_input=constant_image, - untyped_buffer_input=constant_image, - simple_input=input, - array_input=[ input, input ], - float_arg=float_arg, - int_arg=[ int_arg, int_arg ], - untyped_buffer_output_type="uint8", - extra_func_input=func_input, - vectorize=True) + r = gen(target, + typed_buffer_input=constant_image, + untyped_buffer_input=constant_image, + simple_input=input, + array_input=[ input, input ], + float_arg=float_arg, + int_arg=[ int_arg, int_arg ], + extra_func_input=func_input, + generator_params = { + "untyped_buffer_output.type": hl.UInt(8), + "vectorize": True + }) # return value is a tuple; unpack separately to avoid # making the callsite above unreadable @@ -184,6 +205,7 @@ def test_complexstub(): typed_buffer_output, untyped_buffer_output, static_compiled_buffer_output, + scalar_output, extra_func_output) = r b = simple_output.realize([32, 32, 3], target) @@ -249,6 +271,10 @@ def test_complexstub(): actual = b[x, y, c] assert expected == actual, "Expected %s Actual %s" % (expected, actual) + b = scalar_output.realize([], target) + assert b.type() == hl.Float(32) + assert b[()] == 34.25 + b = extra_func_output.realize([32, 32], target) assert b.type() == hl.Float(64) for x in range(32): @@ -258,6 +284,6 @@ def test_complexstub(): assert expected == actual, "Expected %s Actual %s" % (expected, actual) if __name__ == "__main__": - test_simplestub() - test_looplevel() - test_complexstub() + test_simple(simple_pystub.generate) + test_looplevel(simple_pystub.generate) + test_complex(complex_pystub.generate) diff --git a/python_bindings/correctness/rdom.py b/python_bindings/test/correctness/rdom.py similarity index 100% rename from python_bindings/correctness/rdom.py rename to python_bindings/test/correctness/rdom.py diff --git a/python_bindings/correctness/realize_warnings.py b/python_bindings/test/correctness/realize_warnings.py similarity index 100% rename from python_bindings/correctness/realize_warnings.py rename to python_bindings/test/correctness/realize_warnings.py diff --git a/python_bindings/correctness/target.py b/python_bindings/test/correctness/target.py similarity index 100% rename from python_bindings/correctness/target.py rename to python_bindings/test/correctness/target.py diff --git a/python_bindings/correctness/the_sort_function.c b/python_bindings/test/correctness/the_sort_function.c similarity index 100% rename from python_bindings/correctness/the_sort_function.c rename to python_bindings/test/correctness/the_sort_function.c diff --git a/python_bindings/correctness/tuple_select.py b/python_bindings/test/correctness/tuple_select.py similarity index 97% rename from python_bindings/correctness/tuple_select.py rename to python_bindings/test/correctness/tuple_select.py index 7a7b3fee3e4b..653c2420c5be 100644 --- a/python_bindings/correctness/tuple_select.py +++ b/python_bindings/test/correctness/tuple_select.py @@ -64,7 +64,7 @@ def test_tuple_select(): f[x, y] = hl.tuple_select((x < 30, y < 30), (x, y), x + y < 100, (x-1, y-2), (x-100, y-200)) - except RuntimeError as e: + except hl.HalideError as e: assert 'tuple_select() may not mix Expr and Tuple for the condition elements.' in str(e) else: assert False, 'Did not see expected exception!' @@ -73,7 +73,7 @@ def test_tuple_select(): try: f = hl.Func('f') f[x, y] = hl.tuple_select((x < 30, y < 30), (x, y, 0), (1, 2, 3, 4)) - except RuntimeError as e: + except hl.HalideError as e: assert 'tuple_select() requires all Tuples to have identical sizes' in str(e) else: assert False, 'Did not see expected exception!' diff --git a/python_bindings/correctness/type.py b/python_bindings/test/correctness/type.py similarity index 100% rename from python_bindings/correctness/type.py rename to python_bindings/test/correctness/type.py diff --git a/python_bindings/correctness/user_context_test.py b/python_bindings/test/correctness/user_context_test.py similarity index 100% rename from python_bindings/correctness/user_context_test.py rename to python_bindings/test/correctness/user_context_test.py diff --git a/python_bindings/correctness/var.py b/python_bindings/test/correctness/var.py similarity index 100% rename from python_bindings/correctness/var.py rename to python_bindings/test/correctness/var.py diff --git a/python_bindings/test/generators/CMakeLists.txt b/python_bindings/test/generators/CMakeLists.txt new file mode 100644 index 000000000000..8fcfb0734601 --- /dev/null +++ b/python_bindings/test/generators/CMakeLists.txt @@ -0,0 +1,40 @@ +include(PythonExtensionHelpers) + +set(GENERATORS + addconstant + bit + complex + simple + user_context + ) + +# Some Generators require extra Halide Target Features to be set. +set(FEATURES_user_context user_context) + +# Some Generators have undefined types, sizes, etc that are useful for Stubs extensions, +# but unacceptable for AOT Extensions; ensure that all of those are explicitly +# specified for AOT. (We currently don't use or test these in AOT form, so the settings +# are somewhat arbitrary.) +set(GENPARAMS_complex + array_input.size=2 + array_input.type=uint8 + int_arg.size=2 + simple_input.type=uint8 + untyped_buffer_input.type=uint8 + untyped_buffer_output.type=uint8) + +set(GENPARAMS_simple + func_input.type=uint8) + +foreach (GEN IN LISTS GENERATORS) + add_python_aot_extension(py_aot_${GEN} + GENERATOR ${GEN} + FEATURES ${FEATURES_${GEN}} + PARAMS ${GENPARAMS_${GEN}} + SOURCES ${GEN}_generator.cpp) + + add_python_stub_extension(py_stub_${GEN} + SOURCES ${GEN}_generator.cpp + GENERATOR ${GEN} + MODULE ${GEN}_pystub) +endforeach () diff --git a/python_bindings/correctness/addconstant_generator.cpp b/python_bindings/test/generators/addconstant_generator.cpp similarity index 100% rename from python_bindings/correctness/addconstant_generator.cpp rename to python_bindings/test/generators/addconstant_generator.cpp diff --git a/python_bindings/correctness/bit_generator.cpp b/python_bindings/test/generators/bit_generator.cpp similarity index 100% rename from python_bindings/correctness/bit_generator.cpp rename to python_bindings/test/generators/bit_generator.cpp diff --git a/python_bindings/correctness/complexstub_generator.cpp b/python_bindings/test/generators/complex_generator.cpp similarity index 93% rename from python_bindings/correctness/complexstub_generator.cpp rename to python_bindings/test/generators/complex_generator.cpp index 1b6e3274a806..f39cc9e4a31b 100644 --- a/python_bindings/correctness/complexstub_generator.cpp +++ b/python_bindings/test/generators/complex_generator.cpp @@ -15,9 +15,8 @@ Halide::Buffer make_image(int extra) { return im; } -class ComplexStub : public Halide::Generator { +class Complex : public Halide::Generator { public: - GeneratorParam untyped_buffer_output_type{"untyped_buffer_output_type", Float(32)}; GeneratorParam vectorize{"vectorize", true}; GeneratorParam intermediate_level{"intermediate_level", LoopLevel::root()}; @@ -35,6 +34,7 @@ class ComplexStub : public Halide::Generator { Output> typed_buffer_output{"typed_buffer_output"}; Output> untyped_buffer_output{"untyped_buffer_output"}; Output> static_compiled_buffer_output{"static_compiled_buffer_output"}; + Output scalar_output{"scalar_output"}; void configure() { // Pointers returned by add_input() are managed by the Generator; @@ -51,7 +51,7 @@ class ComplexStub : public Halide::Generator { // assert-fail, because there is no type constraint set: the type // will end up as whatever we infer from the values put into it. We'll use an // explicit GeneratorParam to allow us to set it. - untyped_buffer_output(x, y, c) = cast(untyped_buffer_output_type, untyped_buffer_input(x, y, c)); + untyped_buffer_output(x, y, c) = cast(untyped_buffer_output.type(), untyped_buffer_input(x, y, c)); // Gratuitous intermediate for the purpose of exercising // GeneratorParam @@ -72,6 +72,8 @@ class ComplexStub : public Halide::Generator { static_compiled_buffer_output = static_compiled_buffer; (*extra_func_output)(x, y) = cast((*extra_func_input)(x, y, 0) + 1); + + scalar_output() = float_arg + int_arg; } void schedule() { @@ -90,4 +92,4 @@ class ComplexStub : public Halide::Generator { } // namespace -HALIDE_REGISTER_GENERATOR(ComplexStub, complexstub) +HALIDE_REGISTER_GENERATOR(Complex, complex) diff --git a/python_bindings/correctness/simplestub_generator.cpp b/python_bindings/test/generators/simple_generator.cpp similarity index 87% rename from python_bindings/correctness/simplestub_generator.cpp rename to python_bindings/test/generators/simple_generator.cpp index c7f4e56b7de7..759d6184b627 100644 --- a/python_bindings/correctness/simplestub_generator.cpp +++ b/python_bindings/test/generators/simple_generator.cpp @@ -2,7 +2,7 @@ namespace { -class SimpleStub : public Halide::Generator { +class Simple : public Halide::Generator { public: GeneratorParam offset{"offset", 0}; GeneratorParam compute_level{"compute_level", LoopLevel::root()}; @@ -27,4 +27,4 @@ class SimpleStub : public Halide::Generator { } // namespace -HALIDE_REGISTER_GENERATOR(SimpleStub, simplestub) +HALIDE_REGISTER_GENERATOR(Simple, simple) diff --git a/python_bindings/correctness/user_context_generator.cpp b/python_bindings/test/generators/user_context_generator.cpp similarity index 100% rename from python_bindings/correctness/user_context_generator.cpp rename to python_bindings/test/generators/user_context_generator.cpp diff --git a/python_bindings/todo.txt b/python_bindings/todo.txt index dfb2bdb780bb..c73685c38443 100644 --- a/python_bindings/todo.txt +++ b/python_bindings/todo.txt @@ -25,7 +25,6 @@ - InlineReductions - IROperator - LoopLevel - - MachineParams - Module - OutputImageParam - Pipeline diff --git a/python_bindings/tutorial/CMakeLists.txt b/python_bindings/tutorial/CMakeLists.txt index e605282e5f0a..3c2bcb674060 100644 --- a/python_bindings/tutorial/CMakeLists.txt +++ b/python_bindings/tutorial/CMakeLists.txt @@ -1,4 +1,4 @@ -set(TESTS +set(tests lesson_01_basics.py lesson_02_input_image.py lesson_03_debugging_1.py @@ -16,48 +16,72 @@ set(TESTS lesson_14_types.py ) -make_shell_path(PYTHONPATH "$" "$") +set(PYPATH_lesson_10_aot_compilation_run "$") -foreach (TEST IN LISTS TESTS) - get_filename_component(TEST_NAME ${TEST} NAME_WE) - add_test(NAME python_tutorial_${TEST_NAME} - COMMAND Python3::Interpreter "$") +foreach (test IN LISTS tests) + if (TARGET_WEBASSEMBLY AND Halide_TARGET MATCHES "wasm" AND test MATCHES "lesson_10") + message(WARNING "Not all tutorials build under WASM.") + continue() + endif () - set_tests_properties(python_tutorial_${TEST_NAME} PROPERTIES - LABELS python - ENVIRONMENT "PYTHONPATH=${PYTHONPATH};HL_TARGET=${Halide_TARGET}") + cmake_path(GET test STEM test_name) + add_python_test( + FILE "${test}" + LABEL python_tutorial + PYTHONPATH "${PYPATH_${test_name}}" + ) endforeach () -## Add some hacks for getting CMake to delay compiling lesson_10_halide until after the test has run. The "better" way -## of doing this might be to treat lesson 10 like an app and give it its own CMakeLists.txt, but since this is a one-off -## it is probably less maintenance work to do it like this. +if (TARGET_WEBASSEMBLY AND Halide_TARGET MATCHES "wasm") + message(WARNING "Not all tutorials build under WASM.") +else () + ## Add some hacks for getting CMake to delay compiling lesson_10_halide until after the test has run. The "better" way + ## of doing this might be to treat lesson 10 like an app and give it its own CMakeLists.txt, but since this is a one-off + ## it is probably less maintenance work to do it like this. + + # Note that the following tests that are mentioned below were created in the above foreach loop: + # 1. python_tutorial_lesson_10_aot_compilation_generate + # 2. python_tutorial_lesson_10_aot_compilation_run + # The test `python_tutorial_lesson_10_compile` below is responsible for running (1) in service of (2). + + # This dummy command "generates" the files that the TEST python_tutorial_lesson_10_aot_compilation_generate will + # actually generate as part of the fixture below. + add_custom_command(OUTPUT lesson_10_halide.py.cpp lesson_10_halide.o + COMMAND "${CMAKE_COMMAND}" -E echo Dummy command for lesson 10 sources. + VERBATIM) + + # This target allows CMake to build lesson_10_halide.so (or whatever the correct extension is) as part of the tests + # later. It is excluded from ALL since it isn't valid to build outside of this context. + Python3_add_library(lesson_10_halide MODULE EXCLUDE_FROM_ALL + lesson_10_halide.py.cpp + lesson_10_halide.o) + + target_link_libraries(lesson_10_halide PRIVATE Halide::Runtime) -# This dummy command "generates" the files that the TEST python_tutorial_lesson_10_aot_compilation_generate will -# actually generate as part of the fixture below. -add_custom_command(OUTPUT lesson_10_halide.py.cpp lesson_10_halide.o - COMMAND "${CMAKE_COMMAND}" -E echo Dummy command for lesson 10 sources. - VERBATIM) + # The fixture "py_lesson_10" orchestrates running the generator part of the lesson first, then the build for the + # library, and finally runs python_tutorial_lesson_10_aot_compilation_run. The ..._compile test invokes CMake on + # the current build for the above library. + add_test(NAME python_tutorial_lesson_10_compile + COMMAND "${CMAKE_COMMAND}" --build "${CMAKE_BINARY_DIR}" --config $ --target lesson_10_halide) -# This target allows CMake to build lesson_10_halide.so (or whatever the correct extension is) as part of the tests -# later. It is excluded from ALL since it isn't valid to build outside of this context. -Python3_add_library(lesson_10_halide MODULE EXCLUDE_FROM_ALL - lesson_10_halide.py.cpp - lesson_10_halide.o) + set_tests_properties(python_tutorial_lesson_10_aot_compilation_generate PROPERTIES + FIXTURES_SETUP py_lesson_10) -target_link_libraries(lesson_10_halide PRIVATE Halide::Runtime) + set_tests_properties(python_tutorial_lesson_10_compile PROPERTIES + FIXTURES_SETUP py_lesson_10 + DEPENDS python_tutorial_lesson_10_aot_compilation_generate) -# The fixture "py_lesson_10" orchestrates running the generator part of the lesson first, then the build for the -# library, and finally runs python_tutorial_lesson_10_aot_compilation_run. The ..._compile test invokes CMake on -# the current build for the above library. -add_test(NAME python_tutorial_lesson_10_compile - COMMAND "${CMAKE_COMMAND}" --build "${CMAKE_BINARY_DIR}" --config $ --target lesson_10_halide) + set_tests_properties(python_tutorial_lesson_10_aot_compilation_run PROPERTIES + FIXTURES_REQUIRED py_lesson_10) +endif () -set_tests_properties(python_tutorial_lesson_10_aot_compilation_generate PROPERTIES - FIXTURES_SETUP py_lesson_10) +## +# Packaging +## -set_tests_properties(python_tutorial_lesson_10_compile PROPERTIES - FIXTURES_SETUP py_lesson_10 - DEPENDS python_tutorial_lesson_10_aot_compilation_generate) +include(GNUInstallDirs) -set_tests_properties(python_tutorial_lesson_10_aot_compilation_run PROPERTIES - FIXTURES_REQUIRED py_lesson_10) +install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/ + DESTINATION ${CMAKE_INSTALL_DOCDIR}/tutorial-python + COMPONENT Halide_Documentation + FILES_MATCHING PATTERN "*.py") diff --git a/python_bindings/tutorial/lesson_14_types.py b/python_bindings/tutorial/lesson_14_types.py index 927995ff3694..b61a7b90ddf1 100644 --- a/python_bindings/tutorial/lesson_14_types.py +++ b/python_bindings/tutorial/lesson_14_types.py @@ -64,11 +64,11 @@ def main(): # You can also query any defined hl.Func for the types it produces. f1 = hl.Func("f1") f1[x] = hl.cast(hl.UInt(8), x) - assert f1.output_types()[0] == hl.UInt(8) + assert f1.types()[0] == hl.UInt(8) f2 = hl.Func("f2") f2[x] = (x, hl.sin(x)) - assert f2.output_types()[0] == hl.Int(32) and f2.output_types()[1] == hl.Float(32) + assert f2.types()[0] == hl.Int(32) and f2.types()[1] == hl.Float(32) # Type promotion rules. if True: diff --git a/run-clang-format.sh b/run-clang-format.sh index a09cfc27e824..66682418dfc8 100755 --- a/run-clang-format.sh +++ b/run-clang-format.sh @@ -4,23 +4,23 @@ set -e ROOT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -# We are currently standardized on using LLVM/Clang12 for this script. +# We are currently standardized on using LLVM/Clang14 for this script. # Note that this is totally independent of the version of LLVM that you -# are using to build Halide itself. If you don't have LLVM12 installed, +# are using to build Halide itself. If you don't have LLVM14 installed, # you can usually install what you need easily via: # -# sudo apt-get install llvm-12 clang-12 libclang-12-dev clang-tidy-12 -# export CLANG_FORMAT_LLVM_INSTALL_DIR=/usr/lib/llvm-12 +# sudo apt-get install llvm-14 clang-14 libclang-14-dev clang-tidy-14 +# export CLANG_FORMAT_LLVM_INSTALL_DIR=/usr/lib/llvm-14 [ -z "$CLANG_FORMAT_LLVM_INSTALL_DIR" ] && echo "CLANG_FORMAT_LLVM_INSTALL_DIR must point to an LLVM installation dir for this script." && exit echo CLANG_FORMAT_LLVM_INSTALL_DIR = ${CLANG_FORMAT_LLVM_INSTALL_DIR} VERSION=$(${CLANG_FORMAT_LLVM_INSTALL_DIR}/bin/clang-format --version) -if [[ ${VERSION} =~ .*version\ 12.* ]] +if [[ ${VERSION} =~ .*version\ 14.* ]] then - echo "clang-format version 12 found." + echo "clang-format version 14 found." else - echo "CLANG_FORMAT_LLVM_INSTALL_DIR must point to an LLVM 12 install!" + echo "CLANG_FORMAT_LLVM_INSTALL_DIR must point to an LLVM 14 install!" exit 1 fi diff --git a/run-clang-tidy.sh b/run-clang-tidy.sh index 379d1d0d40bd..beb73c04c375 100755 --- a/run-clang-tidy.sh +++ b/run-clang-tidy.sh @@ -8,23 +8,23 @@ ROOT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" FIX=$1 -# We are currently standardized on using LLVM/Clang12 for this script. +# We are currently standardized on using LLVM/Clang14 for this script. # Note that this is totally independent of the version of LLVM that you -# are using to build Halide itself. If you don't have LLVM12 installed, +# are using to build Halide itself. If you don't have LLVM14 installed, # you can usually install what you need easily via: # -# sudo apt-get install llvm-12 clang-12 libclang-12-dev clang-tidy-12 -# export CLANG_TIDY_LLVM_INSTALL_DIR=/usr/lib/llvm-12 +# sudo apt-get install llvm-14 clang-14 libclang-14-dev clang-tidy-14 +# export CLANG_TIDY_LLVM_INSTALL_DIR=/usr/lib/llvm-14 [ -z "$CLANG_TIDY_LLVM_INSTALL_DIR" ] && echo "CLANG_TIDY_LLVM_INSTALL_DIR must point to an LLVM installation dir for this script." && exit echo CLANG_TIDY_LLVM_INSTALL_DIR = ${CLANG_TIDY_LLVM_INSTALL_DIR} VERSION=$(${CLANG_TIDY_LLVM_INSTALL_DIR}/bin/clang-tidy --version) -if [[ ${VERSION} =~ .*version\ 12.* ]] +if [[ ${VERSION} =~ .*version\ 14.* ]] then - echo "clang-tidy version 12 found." + echo "clang-tidy version 14 found." else - echo "CLANG_TIDY_LLVM_INSTALL_DIR must point to an LLVM 12 install!" + echo "CLANG_TIDY_LLVM_INSTALL_DIR must point to an LLVM 14 install!" exit 1 fi @@ -49,7 +49,7 @@ cmake -DCMAKE_BUILD_TYPE=Debug \ # We must populate the includes directory to check things outside of src/ cmake --build ${CLANG_TIDY_BUILD_DIR} --target HalideIncludes -RUN_CLANG_TIDY=${CLANG_TIDY_LLVM_INSTALL_DIR}/share/clang/run-clang-tidy.py +RUN_CLANG_TIDY=${CLANG_TIDY_LLVM_INSTALL_DIR}/bin/run-clang-tidy # We deliberately skip apps/ and test/ for now, as the compile commands won't include # generated headers files from Generators. diff --git a/src/AbstractGenerator.cpp b/src/AbstractGenerator.cpp new file mode 100644 index 000000000000..52bd89553e38 --- /dev/null +++ b/src/AbstractGenerator.cpp @@ -0,0 +1,290 @@ +#include "AbstractGenerator.h" +#include "BoundaryConditions.h" +#include "Derivative.h" +#include "Generator.h" + +namespace Halide { +namespace Internal { + +namespace { + +Argument to_argument(const Internal::Parameter ¶m) { + return Argument(param.name(), + param.is_buffer() ? Argument::InputBuffer : Argument::InputScalar, + param.type(), + param.dimensions(), + param.get_argument_estimates()); +} + +} // namespace + +Module AbstractGenerator::build_module(const std::string &function_name) { + const LinkageType linkage_type = LinkageType::ExternalPlusMetadata; + + Pipeline pipeline = build_pipeline(); + + AutoSchedulerResults auto_schedule_results; + const auto context = this->context(); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + if (context.auto_schedule()) { + auto_schedule_results = pipeline.auto_schedule(context.target(), context.machine_params()); + } +#else + const auto &asp = context.autoscheduler_params(); + if (!asp.name.empty()) { + debug(1) << "Applying autoscheduler " << asp.name << " to Generator " << name() << " ...\n"; + auto_schedule_results = pipeline.apply_autoscheduler(context.target(), asp); + } else { + debug(1) << "Applying autoscheduler (NONE) to Generator " << name() << " ...\n"; + } +#endif + + std::vector filter_arguments; + const auto arg_infos = arginfos(); + for (const auto &a : arg_infos) { + if (a.dir != ArgInfoDirection::Input) { + continue; + } + for (const auto &p : input_parameter(a.name)) { + filter_arguments.push_back(to_argument(p)); + } + } + + Module result = pipeline.compile_to_module(filter_arguments, function_name, context.target(), linkage_type); +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE + for (const auto &map_entry : external_code_map()) { + result.append(map_entry.second); + } +#endif + + for (const auto &a : arg_infos) { + if (a.dir != ArgInfoDirection::Output) { + continue; + } + const std::vector output_funcs = output_func(a.name); + for (size_t i = 0; i < output_funcs.size(); ++i) { + const Func &f = output_funcs[i]; + + const std::string &from = f.name(); + std::string to = a.name; + if (output_funcs.size() > 1) { + to += "_" + std::to_string(i); + } + + const int tuple_size = f.outputs(); + for (int t = 0; t < tuple_size; ++t) { + const std::string suffix = (tuple_size > 1) ? ("." + std::to_string(t)) : ""; + result.remap_metadata_name(from + suffix, to + suffix); + } + } + } + + result.set_auto_scheduler_results(auto_schedule_results); + + return result; +} + +Module AbstractGenerator::build_gradient_module(const std::string &function_name) { + constexpr int DBG = 1; + + // I doubt these ever need customizing; if they do, we can make them arguments to this function. + const std::string grad_input_pattern = "_grad_loss_for_$OUT$"; + const std::string grad_output_pattern = "_grad_loss_$OUT$_wrt_$IN$"; + const LinkageType linkage_type = LinkageType::ExternalPlusMetadata; + + user_assert(!function_name.empty()) << "build_gradient_module(): function_name cannot be empty\n"; + + Pipeline original_pipeline = build_pipeline(); + + std::vector original_outputs = original_pipeline.outputs(); + + // Construct the adjoint pipeline, which has: + // - All the same inputs as the original, in the same order + // - Followed by one grad-input for each original output + // - Followed by one output for each unique pairing of original-output + original-input. + + // First: the original inputs. Note that scalar inputs remain scalar, + // rather being promoted into zero-dimensional buffers. + std::vector gradient_inputs; + const auto arg_infos = arginfos(); + for (const auto &a : arg_infos) { + if (a.dir != ArgInfoDirection::Input) { + continue; + } + for (const auto &p : input_parameter(a.name)) { + gradient_inputs.push_back(to_argument(p)); + debug(DBG) << " gradient copied input is: " << gradient_inputs.back().name << "\n"; + } + } + + // Next: add a grad-input for each *original* output; these will + // be the same shape as the output (so we should copy estimates from + // those outputs onto these estimates). + // - If an output is an Array, we'll have a separate input for each array element. + + std::vector d_output_imageparams; + for (const auto &a : arg_infos) { + if (a.dir != ArgInfoDirection::Output) { + continue; + } + for (const auto &f : output_func(a.name)) { + const Parameter &p = f.output_buffer().parameter(); + const std::string &output_name = p.name(); + // output_name is something like "funcname_i" + const std::string grad_in_name = replace_all(grad_input_pattern, "$OUT$", output_name); + // TODO(srj): does it make sense for gradient to be a non-float type? + // For now, assume it's always float32 (unless the output is already some float). + const Type grad_in_type = p.type().is_float() ? p.type() : Float(32); + const int grad_in_dimensions = p.dimensions(); + const ArgumentEstimates grad_in_estimates = p.get_argument_estimates(); + internal_assert((int)grad_in_estimates.buffer_estimates.size() == grad_in_dimensions); + + ImageParam d_im(grad_in_type, grad_in_dimensions, grad_in_name); + for (int d = 0; d < grad_in_dimensions; d++) { + d_im.parameter().set_min_constraint_estimate(d, grad_in_estimates.buffer_estimates.at(d).min); + d_im.parameter().set_extent_constraint_estimate(d, grad_in_estimates.buffer_estimates.at(d).extent); + } + d_output_imageparams.push_back(d_im); + gradient_inputs.push_back(to_argument(d_im.parameter())); + + debug(DBG) << " gradient synthesized input is: " << gradient_inputs.back().name << "\n"; + } + } + + // Finally: define the output Func(s), one for each unique output/input pair. + // Note that original_outputs.size() != pi.outputs().size() if any outputs are arrays. + internal_assert(original_outputs.size() == d_output_imageparams.size()) << "original_outputs.size() " << original_outputs.size() << " d_output_imageparams.size() " << d_output_imageparams.size(); + std::vector gradient_outputs; + for (size_t i = 0; i < original_outputs.size(); ++i) { + const Func &original_output = original_outputs.at(i); + const ImageParam &d_output = d_output_imageparams.at(i); + Region bounds; + for (int i = 0; i < d_output.dimensions(); i++) { + bounds.emplace_back(d_output.dim(i).min(), d_output.dim(i).extent()); + } + Func adjoint_func = BoundaryConditions::constant_exterior(d_output, make_zero(d_output.type())); + Derivative d = propagate_adjoints(original_output, adjoint_func, bounds); + + const std::string &output_name = original_output.name(); + for (const auto &a : arg_infos) { + if (a.dir != ArgInfoDirection::Input) { + continue; + } + for (const auto &p : input_parameter(a.name)) { + const std::string &input_name = p.name(); + + if (!p.is_buffer()) { + // Not sure if skipping scalar inputs is correct, but that's + // what the previous version of this code did, so we'll continue for now. + debug(DBG) << " Skipping scalar input " << output_name << " wrt input " << input_name << "\n"; + continue; + } + + // Note that Derivative looks up by name; we don't have the original + // Func, and we can't create a new one with an identical name (since + // Func's ctor will uniquify the name for us). Let's just look up + // by the original string instead. + Func d_f = d(input_name + "_im"); + + std::string grad_out_name = replace_all(replace_all(grad_output_pattern, "$OUT$", output_name), "$IN$", input_name); + if (!d_f.defined()) { + grad_out_name = "_dummy" + grad_out_name; + } + + Func d_out_wrt_in(grad_out_name); + if (d_f.defined()) { + d_out_wrt_in(Halide::_) = d_f(Halide::_); + } else { + debug(DBG) << " No Derivative found for output " << output_name << " wrt input " << input_name << "\n"; + // If there was no Derivative found, don't skip the output; + // just replace with a dummy Func that is all zeros. This ensures + // that the signature of the Pipeline we produce is always predictable. + std::vector vars; + for (int i = 0; i < d_output.dimensions(); i++) { + vars.push_back(Var::implicit(i)); + } + d_out_wrt_in(vars) = make_zero(d_output.type()); + } + + d_out_wrt_in.set_estimates(p.get_argument_estimates().buffer_estimates); + + // Useful for debugging; ordinarily better to leave out + // debug(0) << "\n\n" + // << "output:\n" << FuncWithDependencies(original_output) << "\n" + // << "d_output:\n" << FuncWithDependencies(adjoint_func) << "\n" + // << "input:\n" << FuncWithDependencies(f) << "\n" + // << "d_out_wrt_in:\n" << FuncWithDependencies(d_out_wrt_in) << "\n"; + + gradient_outputs.push_back(d_out_wrt_in); + debug(DBG) << " gradient output is: " << d_out_wrt_in.name() << "\n"; + } + } + } + + Pipeline grad_pipeline = Pipeline(gradient_outputs); + + AutoSchedulerResults auto_schedule_results; + const auto context = this->context(); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + if (context.auto_schedule()) { + auto_schedule_results = grad_pipeline.auto_schedule(context.target(), context.machine_params()); + } +#else + const auto &asp = context.autoscheduler_params(); + if (!asp.name.empty()) { + auto_schedule_results = grad_pipeline.apply_autoscheduler(context.target(), asp); + } +#endif + else { + user_warning << "Autoscheduling is not enabled in build_gradient_module(), so the resulting " + "gradient module will be unscheduled; this is very unlikely to be what you want.\n"; + } + + Module result = grad_pipeline.compile_to_module(gradient_inputs, function_name, context.target(), linkage_type); +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE + user_assert(external_code_map().empty()) + << "Building a gradient-descent module for a Generator with ExternalCode is not supported.\n"; +#endif + + result.set_auto_scheduler_results(auto_schedule_results); + return result; +} + +Callable AbstractGenerator::compile_to_callable(const JITHandlers *jit_handlers, + const std::map *jit_externs) { + Pipeline pipeline = build_pipeline(); + + std::vector arguments; + const auto arg_infos = arginfos(); + for (const auto &a : arg_infos) { + if (a.dir != ArgInfoDirection::Input) { + continue; + } + for (const auto &p : input_parameter(a.name)) { + arguments.push_back(to_argument(p)); + } + } + if (jit_handlers != nullptr) { + pipeline.jit_handlers() = *jit_handlers; + } + if (jit_externs != nullptr) { + pipeline.set_jit_externs(*jit_externs); + } + return pipeline.compile_to_callable(arguments, context().target()); +} + +void AbstractGenerator::set_generatorparam_values(const GeneratorParamsMap &m) { + for (const auto &c : m) { +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + user_assert(c.first != "target" && c.first != "auto_schedule" && c.first != "machine_params") + << "The GeneratorParam '" << c.first << "' cannot be specified via string here; use GeneratorContext instead."; +#else + user_assert(c.first != "target" && c.first != "auto_scheduler") + << "The GeneratorParam '" << c.first << "' cannot be specified via string here; use GeneratorContext instead."; +#endif + set_generatorparam_value(c.first, c.second); + } +} + +} // namespace Internal +} // namespace Halide diff --git a/src/AbstractGenerator.h b/src/AbstractGenerator.h new file mode 100644 index 000000000000..95e904dfd9aa --- /dev/null +++ b/src/AbstractGenerator.h @@ -0,0 +1,239 @@ +#ifndef HALIDE_ABSTRACT_GENERATOR_H_ +#define HALIDE_ABSTRACT_GENERATOR_H_ + +#include +#include +#include +#include + +#include "Callable.h" +#include "Expr.h" +#include "Func.h" +#include "Module.h" +#include "Parameter.h" +#include "Pipeline.h" +#include "Schedule.h" +#include "Target.h" +#include "Type.h" + +namespace Halide { + +class GeneratorContext; +using GeneratorParamsMap = std::map; + +namespace Internal { + +enum class ArgInfoKind { Scalar, + Function, + Buffer }; + +enum class ArgInfoDirection { Input, + Output }; + +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE +using ExternsMap = std::map; +#endif + +/** + * AbstractGenerator is an ABC that defines the API a Generator must provide + * to work with the existing Generator infrastructure (GenGen, RunGen, execute_generator(), + * Generator Stubs). The existing Generator<>-based instances all implement + * this API, but any other code that implements this (and uses RegisterGenerator + * to register itself) should be indistinguishable from a user perspective. + * + * An AbstractGenerator is meant to be "single-use"; typically lifetimes will be + * something like: + * - create an instance (with a specific Target) + * - optionally set GeneratorParam values + * - optionally re-bind inputs (if using in JIT or Stub modes) + * - call build_pipeline() + * - optionally call output_func() to get the output(s) (if using in JIT or Stub modes) + * - discard the instance + * + * AbstractGenerators should be fairly cheap to instantiate! Don't try to re-use + * one by re-setting inputs and calling build_pipeline() multiple times. + * + * Note that an AbstractGenerator instance is (generally) stateful in terms of the order + * that methods should be called; calling the methods out of order may cause + * assert-fails or other undesirable behavior. Read the method notes carefully! + */ +class AbstractGenerator { +public: + virtual ~AbstractGenerator() = default; + + /** ArgInfo is a struct to contain name-and-type information for the inputs and outputs to + * the Pipeline that build_pipeline() will return. + * + * Note that this looks rather similar to Halide::Argument, but unfortunately + * that is not a good fit here, as it cannot represent Func inputs (only + * Buffer and Scalar), nor can it really handle Outputs. + */ + struct ArgInfo { + std::string name; + ArgInfoDirection dir = ArgInfoDirection::Input; + ArgInfoKind kind = ArgInfoKind::Scalar; + // Note that this can have multiple entries for Tuple-valued Inputs or Outputs + std::vector types; + int dimensions = 0; + }; + + /** Return the name of this Generator. (This should always be the name + * used to register it.) */ + virtual std::string name() = 0; + + /** Return the Target and autoscheduler info that this Generator + * was created with. Always legal to call on any AbstractGenerator instance, + * regardless of what other methods have been called. (All AbstractGenerator instances + * are expected to be created with immutable values for these, which can't be + * changed for a given instance after creation. Note that Generator<> based subclasses + * can customize Target somewhat via init_from_context(); see Generator.h for more info.) + * + * CALL-AFTER: any + * CALL-BEFORE: any + */ + virtual GeneratorContext context() const = 0; + + /** Return a list of all the ArgInfos for this generator. The list will be in the order + * that the input and outputs are declared (possibly interleaved). + * Any inputs or outputs added by a configure() method will be in the list, + * at the end, in the order added. + * All input and output names will be unique within a given Generator instance. + * + * CALL-AFTER: configure() + * CALL-BEFORE: any + */ + virtual std::vector arginfos() = 0; + + /** Set the value for a specific GeneratorParam for an AbstractGenerator instance. + * + * Names that aren't known generator names should assert-fail. + * + * Values that can't be parsed for the specific GeneratorParam (e.g. passing "foo" where + * an integer is expected) should assert-fail at some point (either immediately, or when + * build_pipeline() is called) + * + * This can be called multiple times, but only prior to build_pipeline(). + * + * CALL-AFTER: none + * CALL-BEFORE: build_pipeline + */ + // @{ + virtual void set_generatorparam_value(const std::string &name, const std::string &value) = 0; + virtual void set_generatorparam_value(const std::string &name, const LoopLevel &loop_level) = 0; + // @} + + /** Build and return the Pipeline for this AbstractGenerator. This method should be called + * only once per instance. + * + * CALL-AFTER: set_generatorparam_value, bind_input + * CALL-BEFORE: input_parameter, output_func, external_code_map + */ + virtual Pipeline build_pipeline() = 0; + + /** Given the name of an input, return the Parameter(s) for that input. + * (Most inputs will have exactly one, but inputs that are declared as arrays + * will have multiple.) + * + * CALL-AFTER: build_pipeline + * CALL-BEFORE: none + */ + virtual std::vector input_parameter(const std::string &name) = 0; + + /** Given the name of an output, return the Func(s) for that output. + * + * Most outputs will have exactly one, but outputs that are declared as arrays will have multiple. + * + * Note that outputs with Tuple values are still just a single Func, though they do get realized + * as multiple Buffers. + * + * Must be called after build_pipeline(), since the output Funcs will be undefined prior to that. + * + * CALL-AFTER: build_pipeline() + * CALL-BEFORE: none + */ + virtual std::vector output_func(const std::string &name) = 0; + +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE + /** Return the ExternsMap for the Generator, if any. + * + * CALL-AFTER: build_pipeline() + * CALL-BEFORE: n/a + */ + virtual ExternsMap external_code_map() = 0; +#endif + + /** Rebind a specified Input to refer to the given piece of IR, replacing the + * default ImageParam / Param in place for that Input. Basic type-checking is + * done to ensure that inputs are still sane (e.g. types, dimensions, etc must match expectations). + * + * CALL-AFTER: set_generatorparam_value + * CALL-BEFORE: build_pipeline + */ + // @{ + virtual void bind_input(const std::string &name, const std::vector &v) = 0; + virtual void bind_input(const std::string &name, const std::vector &v) = 0; + virtual void bind_input(const std::string &name, const std::vector &v) = 0; + // @} + + /** Emit a Generator Stub (.stub.h) file to the given path. Not all Generators support this. + * + * If you call this method, you should not call any other AbstractGenerator methods + * on this instance, before or after this call. + * + * If the Generator is capable of emitting a Stub, do so and return true. (Errors + * during stub emission should assert-fail rather than returning false.) + * + * If the Generator is not capable of emitting a Stub, do nothing and return false. + * + * CALL-AFTER: none + * CALL-BEFORE: none + */ + virtual bool emit_cpp_stub(const std::string &stub_file_path) = 0; + + // Below are some concrete methods that build on top of the rest of the AbstractGenerator API. + // Note that they are nonvirtual. TODO: would these be better as freestanding methods that + // just take AbstractGeneratorPtr as arguments? + + /** Call generate() and produce a Module for the result. + *If function_name is empty, generator_name() will be used for the function. */ + Module build_module(const std::string &function_name = ""); + + /** + * Build a module that is suitable for using for gradient descent calculation in TensorFlow or PyTorch. + * + * Essentially: + * - A new Pipeline is synthesized from the current Generator (according to the rules below) + * - The new Pipeline is autoscheduled (if autoscheduling is requested, but it would be odd not to do so) + * - The Pipeline is compiled to a Module and returned + * + * The new Pipeline is adjoint to the original; it has: + * - All the same inputs as the original, in the same order + * - Followed by one grad-input for each original output + * - Followed by one output for each unique pairing of original-output + original-input. + * (For the common case of just one original-output, this amounts to being one output for each original-input.) + */ + Module build_gradient_module(const std::string &function_name); + + /** + * JIT the AbstractGenerator into a Callable (using the currently-set + * Target) and return it. + * + * If jit_handlers is not null, set the jitted func's jit_handlers to use a copy of it. + * + * If jit_externs is not null, use it to set the jitted func's external dependencies. + */ + Callable compile_to_callable(const JITHandlers *jit_handlers = nullptr, + const std::map *jit_externs = nullptr); + + /* + * Set all the GeneratorParams in the map. This is equivalent to simply calling the + * `set_generatorparam_value()` method in a loop over the map, but is quite convenient. */ + void set_generatorparam_values(const GeneratorParamsMap &m); +}; + +using AbstractGeneratorPtr = std::unique_ptr; + +} // namespace Internal +} // namespace Halide + +#endif diff --git a/src/AlignLoads.cpp b/src/AlignLoads.cpp index 31ee4aee3037..e3ca95acc542 100644 --- a/src/AlignLoads.cpp +++ b/src/AlignLoads.cpp @@ -78,8 +78,8 @@ class AlignLoads : public IRMutator { // non-constant strides. return IRMutator::visit(op); } - if (!(*const_stride == 1 || *const_stride == 2 || *const_stride == 3)) { - // Handle ramps with stride 1, 2 or 3 only. + if (!(*const_stride == 1 || *const_stride == 2 || *const_stride == 3 || *const_stride == 4)) { + // Handle ramps with stride 1, 2, 3 or 4 only. return IRMutator::visit(op); } diff --git a/src/Bounds.cpp b/src/Bounds.cpp index b20c9e5f257e..8090a20e721f 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -243,6 +243,20 @@ class Bounds : public IRVisitor { interval = Interval::single_point(op); } + void visit(const Reinterpret *op) override { + TRACK_BOUNDS_INTERVAL; + + Type t = op->type.element_of(); + + if (t.is_handle()) { + interval = Interval::everything(); + return; + } + + // Just use the bounds of the type + bounds_of_type(t); + } + void visit(const Cast *op) override { TRACK_BOUNDS_INTERVAL; op->value.accept(this); @@ -1196,6 +1210,20 @@ class Bounds : public IRVisitor { bounds_of_type(t); } } + } else if (op->is_intrinsic(Call::saturating_cast)) { + internal_assert(op->args.size() == 1); + + Expr a = op->args[0]; + a.accept(this); + Interval a_interval = interval; + bounds_of_type(t); + if (a_interval.has_lower_bound()) { + interval.min = saturating_cast(t, a_interval.min); + } + if (a_interval.has_upper_bound()) { + interval.max = saturating_cast(t, a_interval.max); + } + return; } else if (op->is_intrinsic(Call::unsafe_promise_clamped) || op->is_intrinsic(Call::promise_clamped)) { // Unlike an explicit clamp, we are also permitted to @@ -2640,7 +2668,7 @@ class BoxesTouched : public IRGraphVisitor { void visit(const IfThenElse *op) override { TRACK_BOXES_TOUCHED; op->condition.accept(this); - if (expr_uses_vars(op->condition, scope)) { + if (expr_uses_vars(op->condition, scope) || !is_pure(op->condition)) { // We need to simplify the condition to get it into a // canonical form (e.g. (a < b) instead of !(a >= b)) vector> cases; @@ -3558,6 +3586,28 @@ void bounds_test() { check(scope, cast(u8_1) + cast(u8_2), u16(0), u16(255 * 2)); + check(scope, saturating_cast(clamp(x, 5, 10)), cast(5), cast(10)); + { + scope.push("x", Interval(UInt(32).min(), UInt(32).max())); + check(scope, saturating_cast(max(cast(x), cast(5))), cast(5), Int(32).max()); + scope.pop("x"); + } + { + Expr z = Variable::make(Float(32), "z"); + scope.push("z", Interval(cast(-1), cast(1))); + check(scope, saturating_cast(z), cast(-1), cast(1)); + check(scope, saturating_cast(z), cast(-1), cast(1)); + check(scope, saturating_cast(z), cast(-1), cast(1)); + check(scope, saturating_cast(z), cast(0), cast(1)); + scope.pop("z"); + } + { + Expr z = Variable::make(UInt(32), "z"); + scope.push("z", Interval(UInt(32).max(), UInt(32).max())); + check(scope, saturating_cast(z), Int(32).max(), Int(32).max()); + scope.pop("z"); + } + { Scope scope; Expr x = Variable::make(UInt(16), "x"); diff --git a/src/BoundsInference.cpp b/src/BoundsInference.cpp index 298ad885fb10..f9d3784014ae 100644 --- a/src/BoundsInference.cpp +++ b/src/BoundsInference.cpp @@ -860,7 +860,7 @@ class BoundsInference : public IRMutator { // Do any pure inlining (TODO: This is currently slow) for (size_t i = f.size(); i > 0; i--) { - Function func = f[i - 1]; + const Function &func = f[i - 1]; if (inlined[i - 1]) { for (auto &s : stages) { for (auto &cond_val : s.exprs) { diff --git a/src/Buffer.h b/src/Buffer.h index eaff181f7fdc..220c009a8ea1 100644 --- a/src/Buffer.h +++ b/src/Buffer.h @@ -8,7 +8,9 @@ namespace Halide { -template +constexpr int AnyDims = Halide::Runtime::AnyDims; // -1 + +template class Buffer; struct JITUserContext; @@ -153,7 +155,7 @@ class Buffer { } public: - static constexpr int AnyDims = Halide::Runtime::AnyDims; + static constexpr int AnyDims = Halide::AnyDims; static_assert(Dims == AnyDims || Dims >= 0); typedef T ElemType; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 809b3efdd937..c31e37c32a20 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -4,7 +4,9 @@ # The externally-visible header files that go into making Halide.h. # Don't include anything here that includes llvm headers. +# Also *don't* include anything that's only used internally (eg SpirvIR.h). set(HEADER_FILES + AbstractGenerator.h AddAtomicMutex.h AddImageChecks.h AddParameterChecks.h @@ -21,6 +23,7 @@ set(HEADER_FILES BoundsInference.h BoundSmallAllocations.h Buffer.h + Callable.h CanonicalizeGPUVars.h ClampUnsafeAccesses.h Closure.h @@ -101,7 +104,6 @@ set(HEADER_FILES LowerParallelTasks.h LowerWarpShuffles.h MainPage.h - MatlabWrapper.h Memoization.h Module.h ModulusRemainder.h @@ -167,6 +169,7 @@ set(HEADER_FILES ) set(SOURCE_FILES + AbstractGenerator.cpp AddAtomicMutex.cpp AddImageChecks.cpp AddParameterChecks.cpp @@ -183,6 +186,7 @@ set(SOURCE_FILES BoundsInference.cpp BoundSmallAllocations.cpp Buffer.cpp + Callable.cpp CanonicalizeGPUVars.cpp ClampUnsafeAccesses.cpp Closure.cpp @@ -261,7 +265,6 @@ set(SOURCE_FILES Lower.cpp LowerParallelTasks.cpp LowerWarpShuffles.cpp - MatlabWrapper.cpp Memoization.cpp Module.cpp ModulusRemainder.cpp @@ -297,6 +300,7 @@ set(SOURCE_FILES Simplify_Add.cpp Simplify_And.cpp Simplify_Call.cpp + Simplify_Reinterpret.cpp Simplify_Cast.cpp Simplify_Div.cpp Simplify_EQ.cpp @@ -318,6 +322,7 @@ set(SOURCE_FILES SkipStages.cpp SlidingWindow.cpp Solve.cpp + SpirvIR.cpp SplitTuples.cpp StmtToHtml.cpp StorageFlattening.cpp @@ -459,7 +464,6 @@ target_compile_options( $<$:-Wno-missing-prototypes> $<$:-Wno-nonportable-system-include-path> $<$:-Wno-reserved-id-macro> - $<$:-Wno-return-std-move-in-c++11> $<$:-Wno-shadow-field-in-constructor> $<$:-Wno-shadow-field> $<$:-Wno-shorten-64-to-32> @@ -467,6 +471,9 @@ target_compile_options( $<$:-Wno-unused-member-function> $<$:-Wno-unused-template> + # This warning was removed in Clang 13 + $<$,$,13.0>>:-Wno-return-std-move-in-c++11> + $<$:/W3> $<$:/wd4018> # 4018: disable "signed/unsigned mismatch" $<$:/wd4141> # 4141: 'inline' used more than once @@ -519,6 +526,11 @@ if (TARGET_OPENGLCOMPUTE) target_compile_definitions(Halide PRIVATE WITH_OPENGLCOMPUTE) endif () +if (TARGET_SPIRV) + target_compile_definitions(Halide PRIVATE WITH_SPIRV) + target_include_directories(Halide SYSTEM PRIVATE "${SPIRV_INCLUDE_DIR}") +endif () + ## # Add autoschedulers to the build. ## diff --git a/src/Callable.cpp b/src/Callable.cpp new file mode 100644 index 000000000000..6ba334619d22 --- /dev/null +++ b/src/Callable.cpp @@ -0,0 +1,209 @@ +#include + +#include "Argument.h" +#include "Callable.h" +#include "JITModule.h" +#include "Pipeline.h" + +using namespace Halide::Internal; + +namespace Halide { + +struct CallableContents { + mutable RefCount ref_count; + + // Name of the jitted function, here solely for error reporting + std::string name; + + // The cached code + JITCache jit_cache; + + // Save the jit_handlers and jit_externs as they were at the time this + // Callable was created, in case the Pipeline's version is mutated in + // between creation and call -- we want the Callable to remain immutable + // after creation, regardless of what you do to the Func. + JITHandlers saved_jit_handlers; + std::map saved_jit_externs; + + // Encoded values for efficient runtime type checking; + // identical to jit_cache.arguments in length. + std::vector quick_call_check_info; + + // Encoded values for complete runtime type checking, used + // only for make_std_function. Lazily created. + std::vector full_call_check_info; +}; + +namespace Internal { +template<> +RefCount &ref_count(const CallableContents *p) noexcept { + return p->ref_count; +} + +template<> +void destroy(const CallableContents *p) { + delete p; +} +} // namespace Internal + +Callable::Callable() + : contents(new CallableContents) { +} + +Callable::Callable(const std::string &name, + const JITHandlers &jit_handlers, + const std::map &jit_externs, + JITCache &&jit_cache) + : contents(new CallableContents) { + contents->name = name; + contents->jit_cache = std::move(jit_cache); + contents->saved_jit_handlers = jit_handlers; + contents->saved_jit_externs = jit_externs; + + contents->quick_call_check_info.reserve(contents->jit_cache.arguments.size()); + for (const Argument &a : contents->jit_cache.arguments) { + const auto qcci = (a.name == "__user_context") ? + Callable::make_ucon_qcci() : + (a.is_scalar() ? Callable::make_scalar_qcci(a.type) : Callable::make_buffer_qcci()); + contents->quick_call_check_info.push_back(qcci); + } + + // Don't create full_call_check_info yet. +} + +const std::vector &Callable::arguments() const { + return contents->jit_cache.arguments; +} + +Callable::FailureFn Callable::do_check_fail(int bad_idx, size_t argc, const char *verb) const { + const size_t required_arg_count = contents->jit_cache.arguments.size(); + + // TODO: this assumes that the caller uses the no-explicit-JITUserContext call; + // the errors will be misleading otherwise. + constexpr int hidden_args = 1; + + std::ostringstream o; + if (bad_idx < 0) { + o << "Error " << verb << " '" << contents->name << "': " + << "Expected exactly " << (required_arg_count - hidden_args) << " arguments, " + << "but saw " << (argc - hidden_args) << "."; + } else { + // Capture *this to ensure that the CallableContents stay valid as long as the std::function does + const Argument &a = contents->jit_cache.arguments.at(bad_idx); + const char *kind = a.is_scalar() ? "scalar" : "buffer"; + // Note that we don't report the "actual type" here, just the expected type... + // saving the actual type leads to more code bloat than we can justify + // for this. (Consider adding as a debug-only enhancement?) + o << "Error " << verb << " '" << contents->name << "': " + << "Argument " << (bad_idx - hidden_args + 1) + << " of " << (required_arg_count - hidden_args) << " ('" << a.name << "') was expected to be a " + << kind << " of type '" << a.type << "' and dimension " << (int)a.dimensions << ".\n"; + } + std::string msg = o.str(); + + return [*this, msg](JITUserContext *context) -> int { + constexpr int exit_status = halide_error_code_internal_error; // TODO: perhaps return a more useful error code?; + + if (context && context->handlers.custom_error) { + context->handlers.custom_error(context, msg.c_str()); + } else if (contents->saved_jit_handlers.custom_error) { + contents->saved_jit_handlers.custom_error(context, msg.c_str()); + } else { + if (msg.empty()) { + halide_runtime_error << "The pipeline returned exit status " << exit_status << " but halide_error was never called.\n"; + } else { + halide_runtime_error << msg; + } + } + return exit_status; + }; +} + +Callable::FailureFn Callable::check_qcci(size_t argc, const QuickCallCheckInfo *actual_qcci) const { + const size_t required_arg_count = contents->quick_call_check_info.size(); + if (argc == required_arg_count) { + const QuickCallCheckInfo *expected_qcci = contents->quick_call_check_info.data(); + for (size_t i = 0; i < argc; i++) { + if (actual_qcci[i] != expected_qcci[i]) { + return do_check_fail(i, argc, "calling"); + } + } + } else { + return do_check_fail(-1, argc, "calling"); + } + + return nullptr; +} + +Callable::FailureFn Callable::check_fcci(size_t argc, const FullCallCheckInfo *actual_fcci) const { + // Lazily create full_call_check_info upon the first call to make_std_function(). + if (contents->full_call_check_info.empty()) { + contents->full_call_check_info.reserve(contents->jit_cache.arguments.size()); + for (const Argument &a : contents->jit_cache.arguments) { + const auto fcci = a.is_scalar() ? Callable::make_scalar_fcci(a.type) : Callable::make_buffer_fcci(a.type, a.dimensions); + contents->full_call_check_info.push_back(fcci); + } + } + + FailureFn failure_fn = nullptr; + const size_t required_arg_count = contents->full_call_check_info.size(); + if (argc == required_arg_count) { + const FullCallCheckInfo *expected_fcci = contents->full_call_check_info.data(); + for (size_t i = 0; i < argc; i++) { + if (!Callable::is_compatible_fcci(actual_fcci[i], expected_fcci[i])) { + failure_fn = do_check_fail(i, argc, "defining"); + break; + } + } + } else { + failure_fn = do_check_fail(-1, argc, "defining"); + } + + if (failure_fn) { + // Go ahead and call it now, since we know that every possible call will fail. + // (We'll also return it as a sentinel so the caller knows that this is the case; + // if the Callable has hooked the error handler to do nothing, we don't want want + // to try to continue executing this path in the caller.) + JITUserContext empty; + (void)failure_fn(&empty); + } + + return failure_fn; +} + +// Entry point used from the std::function<> variant; we can skip the check_qcci() stuff +// since we verified the signature when we created the std::function, so incorrect types or counts +// should be impossible. +/*static*/ int Callable::call_argv_fast(size_t argc, const void *const *argv) const { + // Callable should enforce these, so we can use assert() instead of internal_assert() -- + // this is effectively just documentation that these invariants are expected to have + // been enforced prior to this call. + assert(contents->jit_cache.jit_target.has_feature(Target::UserContext)); + assert(contents->jit_cache.arguments[0].name == "__user_context"); + + JITUserContext *context = *(JITUserContext **)const_cast(argv[0]); + assert(context != nullptr); + + JITFuncCallContext jit_call_context(context, contents->saved_jit_handlers); + + int exit_status = contents->jit_cache.call_jit_code(contents->jit_cache.jit_target, argv); + + // If we're profiling, report runtimes and reset profiler stats. + contents->jit_cache.finish_profiling(context); + + jit_call_context.finalize(exit_status); + + return exit_status; +} + +int Callable::call_argv_checked(size_t argc, const void *const *argv, const QuickCallCheckInfo *actual_qcci) const { + // It's *essential* we call this for safety. + const auto failure_fn = check_qcci(argc, actual_qcci); + if (failure_fn) { + JITUserContext *context = *(JITUserContext **)const_cast(argv[0]); + return failure_fn(context); + } + return call_argv_fast(argc, argv); +} + +} // namespace Halide diff --git a/src/Callable.h b/src/Callable.h new file mode 100644 index 000000000000..440ba3aab21d --- /dev/null +++ b/src/Callable.h @@ -0,0 +1,387 @@ +#ifndef HALIDE_CALLABLE_H +#define HALIDE_CALLABLE_H + +/** \file + * + * Defines the front-end class representing a jitted, callable Halide pipeline. + */ + +#include +#include + +#include "Buffer.h" +#include "IntrusivePtr.h" +#include "JITModule.h" + +namespace Halide { + +struct Argument; +struct CallableContents; + +namespace PythonBindings { +class PyCallable; +} + +namespace Internal { + +template +struct IsHalideBuffer : std::false_type {}; + +template +struct IsHalideBuffer<::Halide::Buffer> : std::true_type {}; + +template +struct IsHalideBuffer<::Halide::Runtime::Buffer> : std::true_type {}; + +template<> +struct IsHalideBuffer : std::true_type {}; + +template<> +struct IsHalideBuffer : std::true_type {}; + +template +struct HalideBufferStaticTypeAndDims { + static constexpr halide_type_t type() { + return halide_type_t(); + } + static constexpr int dims() { + return -1; + } +}; + +template +struct HalideBufferStaticTypeAndDims<::Halide::Buffer> { + static constexpr halide_type_t type() { + if constexpr (std::is_void_v) { + return halide_type_t(); + } else { + return halide_type_of(); + } + } + static constexpr int dims() { + return Dims; + } +}; + +template +struct HalideBufferStaticTypeAndDims<::Halide::Runtime::Buffer> { + static constexpr halide_type_t type() { + if constexpr (std::is_void_v) { + return halide_type_t(); + } else { + return halide_type_of(); + } + } + static constexpr int dims() { + return Dims; + } +}; + +} // namespace Internal + +class Callable { +private: + friend class Pipeline; + friend struct CallableContents; + friend class PythonBindings::PyCallable; + + Internal::IntrusivePtr contents; + + // --------------------------------- + + // This value is constructed so we can do the necessary runtime check + // with a single 16-bit compare. It's designed to to the minimal checking + // necessary to ensure that the arguments are well-formed, but not necessarily + // "correct"; in particular, it deliberately skips checking type-and-dim + // of Buffer arguments, since the generated code has assertions to check + // for that anyway. + using QuickCallCheckInfo = uint16_t; + + static constexpr QuickCallCheckInfo _make_qcci(uint8_t code, uint8_t bits) { + return (((uint16_t)code) << 8) | (uint16_t)bits; + } + + static constexpr QuickCallCheckInfo make_scalar_qcci(halide_type_t t) { + return _make_qcci(t.code, t.bits); + } + + static constexpr QuickCallCheckInfo make_buffer_qcci() { + constexpr uint8_t fake_bits_buffer_cci = 3; + return _make_qcci(halide_type_handle, fake_bits_buffer_cci); + } + + static constexpr QuickCallCheckInfo make_ucon_qcci() { + constexpr uint8_t fake_bits_ucon_cci = 5; + return _make_qcci(halide_type_handle, fake_bits_ucon_cci); + } + + template + static constexpr QuickCallCheckInfo make_qcci() { + using T0 = typename std::remove_const::type>::type; + if constexpr (std::is_same::value) { + return make_ucon_qcci(); + } else if constexpr (Internal::IsHalideBuffer::value) { + // Don't bother checking type-and-dimensions here (the callee will do that) + return make_buffer_qcci(); + } else if constexpr (std::is_arithmetic::value || std::is_pointer::value) { + return make_scalar_qcci(halide_type_of()); + } else { + // static_assert(false) will fail all the time, even inside constexpr, + // but gating on sizeof(T) is a nice trick that ensures we will always + // fail here (since no T is ever size 0). + static_assert(!sizeof(T), "Illegal type passed to Callable."); + } + } + + template + static constexpr std::array make_qcci_array() { + return std::array{make_qcci()...}; + } + + // --------------------------------- + + // This value is constructed so we can do a complete type-and-dim check + // of Buffers, and is used for the make_std_function() method, to ensure + // that if we specify static type-and-dims for Buffers, the ones we specify + // actually match the underlying code. We take horrible liberties with halide_type_t + // to make this happen -- specifically, encoding dimensionality and buffer-vs-scalar + // into the 'lanes' field -- but that's ok since this never escapes into other usage. + using FullCallCheckInfo = halide_type_t; + + static constexpr FullCallCheckInfo _make_fcci(halide_type_t type, int dims, bool is_buffer) { + return type.with_lanes(((uint16_t)dims << 1) | (uint16_t)(is_buffer ? 1 : 0)); + } + + static constexpr FullCallCheckInfo make_scalar_fcci(halide_type_t t) { + return _make_fcci(t, 0, false); + } + + static constexpr FullCallCheckInfo make_buffer_fcci(halide_type_t t, int dims) { + return _make_fcci(t, dims, true); + } + + static bool is_compatible_fcci(FullCallCheckInfo actual, FullCallCheckInfo expected) { + if (actual == expected) { + return true; // my, that was easy + } + + // Might still be compatible + const bool a_is_buffer = (actual.lanes & 1) != 0; + const int a_dims = (((int16_t)actual.lanes) >> 1); + const halide_type_t a_type = actual.with_lanes(0); + + const bool e_is_buffer = (expected.lanes & 1) != 0; + const int e_dims = (((int16_t)expected.lanes) >> 1); + const halide_type_t e_type = expected.with_lanes(0); + + const bool types_match = (a_type == halide_type_t()) || + (e_type == halide_type_t()) || + (a_type == e_type); + + const bool dims_match = a_dims < 0 || + e_dims < 0 || + a_dims == e_dims; + + return a_is_buffer == e_is_buffer && types_match && dims_match; + } + + template + static constexpr FullCallCheckInfo make_fcci() { + using T0 = typename std::remove_const::type>::type; + if constexpr (Internal::IsHalideBuffer::value) { + using TypeAndDims = Internal::HalideBufferStaticTypeAndDims; + return make_buffer_fcci(TypeAndDims::type(), TypeAndDims::dims()); + } else if constexpr (std::is_arithmetic::value || std::is_pointer::value) { + return make_scalar_fcci(halide_type_of()); + } else { + // static_assert(false) will fail all the time, even inside constexpr, + // but gating on sizeof(T) is a nice trick that ensures we will always + // fail here (since no T is ever size 0). + static_assert(!sizeof(T), "Illegal type passed to Callable."); + } + } + + template + static constexpr std::array make_fcci_array() { + return std::array{make_fcci()...}; + } + + // --------------------------------- + + template + struct ArgvStorage { + const void *argv[Size]; + // We need a place to store the scalar inputs, since we need a pointer + // to them and it's better to avoid relying on stack spill of arguments. + // Note that this will usually have unused slots, but it's cheap and easy + // compile-time allocation on the stack. + uintptr_t argv_scalar_store[Size]; + + template + explicit ArgvStorage(Args &&...args) { + fill_slots(0, std::forward(args)...); + } + + private: + template + HALIDE_ALWAYS_INLINE void fill_slot(size_t idx, const ::Halide::Buffer &value) { + // Don't call ::Halide::Buffer::raw_buffer(): it includes "user_assert(defined())" + // as part of the wrapper code, and we want this lean-and-mean. Instead, stick in a null + // value for undefined buffers, and let the Halide pipeline fail with the usual null-ptr + // check. (Note that H::R::B::get() *never* returns null; you must check defined() first.) + argv[idx] = value.defined() ? value.get()->raw_buffer() : nullptr; + } + + template + HALIDE_ALWAYS_INLINE void fill_slot(size_t idx, const ::Halide::Runtime::Buffer &value) { + argv[idx] = value.raw_buffer(); + } + + HALIDE_ALWAYS_INLINE + void fill_slot(size_t idx, halide_buffer_t *value) { + argv[idx] = value; + } + + HALIDE_ALWAYS_INLINE + void fill_slot(size_t idx, const halide_buffer_t *value) { + argv[idx] = value; + } + + HALIDE_ALWAYS_INLINE + void fill_slot(size_t idx, JITUserContext *value) { + auto *dest = &argv_scalar_store[idx]; + *dest = (uintptr_t)value; + argv[idx] = dest; + } + + template + HALIDE_ALWAYS_INLINE void fill_slot(size_t idx, const T &value) { + auto *dest = &argv_scalar_store[idx]; + *(T *)dest = value; + argv[idx] = dest; + } + + template + HALIDE_ALWAYS_INLINE void fill_slots(size_t idx, const T &value) { + fill_slot(idx, value); + } + + template + HALIDE_ALWAYS_INLINE void fill_slots(int idx, First &&first, Second &&second, Rest &&...rest) { + fill_slots(idx, std::forward(first)); + fill_slots(idx + 1, std::forward(second), std::forward(rest)...); + } + }; + + Callable(); + Callable(const std::string &name, + const JITHandlers &jit_handlers, + const std::map &jit_externs, + Internal::JITCache &&jit_cache); + + // Note that the first entry in argv must always be a JITUserContext*. + int call_argv_checked(size_t argc, const void *const *argv, const QuickCallCheckInfo *actual_cci) const; + int call_argv_fast(size_t argc, const void *const *argv) const; + + using FailureFn = std::function; + + FailureFn do_check_fail(int bad_idx, size_t argc, const char *verb) const; + FailureFn check_qcci(size_t argc, const QuickCallCheckInfo *actual_cci) const; + FailureFn check_fcci(size_t argc, const FullCallCheckInfo *actual_cci) const; + + template + int call(JITUserContext *context, Args &&...args) const { + // This is built at compile time! + static constexpr auto actual_arg_types = make_qcci_array(); + + constexpr size_t count = sizeof...(args) + 1; + ArgvStorage argv(context, std::forward(args)...); + return call_argv_checked(count, &argv.argv[0], actual_arg_types.data()); + } + + /** Return the expected Arguments for this Callable, in the order they must be specified, including all outputs. + * Note that the first entry will *always* specify a JITUserContext. */ + const std::vector &arguments() const; + +public: + template + HALIDE_FUNCTION_ATTRS int + operator()(JITUserContext *context, Args &&...args) const { + return call(context, std::forward(args)...); + } + + template + HALIDE_FUNCTION_ATTRS int + operator()(Args &&...args) const { + JITUserContext empty; + return call(&empty, std::forward(args)...); + } + + /** This allows us to construct a std::function<> that wraps the Callable. + * This is nice in that it is, well, just a std::function, but also in that + * since the argument-count-and-type checking are baked into the language, + * we can do the relevant checking only once -- when we first create the std::function -- + * and skip it on all actual *calls* to the function, making it slightly more efficient. + * It's also more type-forgiving, in that the usual C++ numeric coercion rules apply here. + * + * The downside is that there isn't (currently) any way to automatically infer + * the static types reliably, since we may be using (e.g.) a Param, where the + * type in question isn't available to the C++ compiler. This means that the coder + * must supply the correct type signature when calling this function -- but the good news + * is that if you get it wrong, this function will fail when you call it. (In other words: + * it can't choose the right thing for you, but it can tell you when you do the wrong thing.) + * + * TODO: it's possible that we could infer the correct signatures in some cases, + * and only fail for the ambiguous cases, but that would require a lot more template-fu + * here and elsewhere. I think this is good enough for now. + * + * TODO: is it possible to annotate the result of a std::function<> with HALIDE_FUNCTION_ATTRS? + */ + template + std::function + make_std_function() const { + if constexpr (std::is_same_v) { + constexpr auto actual_arg_types = make_fcci_array(); + const auto failure_fn = check_fcci(actual_arg_types.size(), actual_arg_types.data()); + if (failure_fn) { + // Return a wrapper for the failure_fn in case the error handler is a no-op, + // so that subsequent calls won't attempt to use possibly-wrong argv packing. + return [*this, failure_fn](auto &&first, auto &&...rest) -> int { + return failure_fn(std::forward(first)); + }; + } + + // Capture *this to ensure that the CallableContents stay valid as long as the std::function does + return [*this](auto &&first, auto &&...rest) -> int { + constexpr size_t count = 1 + sizeof...(rest); + ArgvStorage argv(std::forward(first), std::forward(rest)...); + return call_argv_fast(count, &argv.argv[0]); + }; + } else { + // Explicitly prepend JITUserContext* as first actual-arg-type. + constexpr auto actual_arg_types = make_fcci_array(); + const auto failure_fn = check_fcci(actual_arg_types.size(), actual_arg_types.data()); + if (failure_fn) { + // Return a wrapper for the failure_fn in case the error handler is a no-op, + // so that subsequent calls won't attempt to use possibly-wrong argv packing. + return [*this, failure_fn](auto &&first, auto &&...rest) -> int { + JITUserContext empty; + return failure_fn(&empty); + }; + } + + // Capture *this to ensure that the CallableContents stay valid as long as the std::function does + return [*this](auto &&first, auto &&...rest) -> int { + // Explicitly prepend an (empty) JITUserContext to the args. + JITUserContext empty; + constexpr size_t count = 1 + 1 + sizeof...(rest); + ArgvStorage argv(&empty, std::forward(first), std::forward(rest)...); + return call_argv_fast(count, &argv.argv[0]); + }; + } + } +}; + +} // namespace Halide + +#endif diff --git a/src/CodeGen_ARM.cpp b/src/CodeGen_ARM.cpp index 02b2ebac4bf2..06acff932f9e 100644 --- a/src/CodeGen_ARM.cpp +++ b/src/CodeGen_ARM.cpp @@ -72,7 +72,8 @@ class CodeGen_ARM : public CodeGen_Posix { }; vector casts, calls, averagings, negations; - string mcpu() const override; + string mcpu_target() const override; + string mcpu_tune() const override; string mattrs() const override; bool use_soft_float_abi() const override; int native_vector_bits() const override; @@ -141,43 +142,43 @@ CodeGen_ARM::CodeGen_ARM(const Target &target) // TODO: We need to match rounding shift right, and negate the RHS. // SQRSHRN, SQRSHRUN, UQRSHRN - Saturating rounding narrowing shift right narrow (by immediate in [1, output bits]) - casts.emplace_back("saturating_rounding_shift_right_narrow", i8_sat(rounding_shift_right(wild_i16x_, wild_u16_))); - casts.emplace_back("saturating_rounding_shift_right_narrow", u8_sat(rounding_shift_right(wild_u16x_, wild_u16_))); - casts.emplace_back("saturating_rounding_shift_right_narrow", u8_sat(rounding_shift_right(wild_i16x_, wild_u16_))); - casts.emplace_back("saturating_rounding_shift_right_narrow", i16_sat(rounding_shift_right(wild_i32x_, wild_u32_))); - casts.emplace_back("saturating_rounding_shift_right_narrow", u16_sat(rounding_shift_right(wild_u32x_, wild_u32_))); - casts.emplace_back("saturating_rounding_shift_right_narrow", u16_sat(rounding_shift_right(wild_i32x_, wild_u32_))); - casts.emplace_back("saturating_rounding_shift_right_narrow", i32_sat(rounding_shift_right(wild_i64x_, wild_u64_))); - casts.emplace_back("saturating_rounding_shift_right_narrow", u32_sat(rounding_shift_right(wild_u64x_, wild_u64_))); - casts.emplace_back("saturating_rounding_shift_right_narrow", u32_sat(rounding_shift_right(wild_i64x_, wild_u64_))); + calls.emplace_back("saturating_rounding_shift_right_narrow", i8_sat(rounding_shift_right(wild_i16x_, wild_u16_))); + calls.emplace_back("saturating_rounding_shift_right_narrow", u8_sat(rounding_shift_right(wild_u16x_, wild_u16_))); + calls.emplace_back("saturating_rounding_shift_right_narrow", u8_sat(rounding_shift_right(wild_i16x_, wild_u16_))); + calls.emplace_back("saturating_rounding_shift_right_narrow", i16_sat(rounding_shift_right(wild_i32x_, wild_u32_))); + calls.emplace_back("saturating_rounding_shift_right_narrow", u16_sat(rounding_shift_right(wild_u32x_, wild_u32_))); + calls.emplace_back("saturating_rounding_shift_right_narrow", u16_sat(rounding_shift_right(wild_i32x_, wild_u32_))); + calls.emplace_back("saturating_rounding_shift_right_narrow", i32_sat(rounding_shift_right(wild_i64x_, wild_u64_))); + calls.emplace_back("saturating_rounding_shift_right_narrow", u32_sat(rounding_shift_right(wild_u64x_, wild_u64_))); + calls.emplace_back("saturating_rounding_shift_right_narrow", u32_sat(rounding_shift_right(wild_i64x_, wild_u64_))); // SQSHL, UQSHL, SQSHLU - Saturating shift left by signed register. for (const Expr &rhs : {wild_i8x_, wild_u8x_}) { - casts.emplace_back("saturating_shift_left", i8_sat(widening_shift_left(wild_i8x_, rhs))); - casts.emplace_back("saturating_shift_left", u8_sat(widening_shift_left(wild_u8x_, rhs))); - casts.emplace_back("saturating_shift_left", u8_sat(widening_shift_left(wild_i8x_, rhs))); + calls.emplace_back("saturating_shift_left", i8_sat(widening_shift_left(wild_i8x_, rhs))); + calls.emplace_back("saturating_shift_left", u8_sat(widening_shift_left(wild_u8x_, rhs))); + calls.emplace_back("saturating_shift_left", u8_sat(widening_shift_left(wild_i8x_, rhs))); } for (const Expr &rhs : {wild_i16x_, wild_u16x_}) { - casts.emplace_back("saturating_shift_left", i16_sat(widening_shift_left(wild_i16x_, rhs))); - casts.emplace_back("saturating_shift_left", u16_sat(widening_shift_left(wild_u16x_, rhs))); - casts.emplace_back("saturating_shift_left", u16_sat(widening_shift_left(wild_i16x_, rhs))); + calls.emplace_back("saturating_shift_left", i16_sat(widening_shift_left(wild_i16x_, rhs))); + calls.emplace_back("saturating_shift_left", u16_sat(widening_shift_left(wild_u16x_, rhs))); + calls.emplace_back("saturating_shift_left", u16_sat(widening_shift_left(wild_i16x_, rhs))); } for (const Expr &rhs : {wild_i32x_, wild_u32x_}) { - casts.emplace_back("saturating_shift_left", i32_sat(widening_shift_left(wild_i32x_, rhs))); - casts.emplace_back("saturating_shift_left", u32_sat(widening_shift_left(wild_u32x_, rhs))); - casts.emplace_back("saturating_shift_left", u32_sat(widening_shift_left(wild_i32x_, rhs))); + calls.emplace_back("saturating_shift_left", i32_sat(widening_shift_left(wild_i32x_, rhs))); + calls.emplace_back("saturating_shift_left", u32_sat(widening_shift_left(wild_u32x_, rhs))); + calls.emplace_back("saturating_shift_left", u32_sat(widening_shift_left(wild_i32x_, rhs))); } // SQSHRN, UQSHRN, SQRSHRUN Saturating narrowing shift right by an (by immediate in [1, output bits]) - casts.emplace_back("saturating_shift_right_narrow", i8_sat(wild_i16x_ >> wild_u16_)); - casts.emplace_back("saturating_shift_right_narrow", u8_sat(wild_u16x_ >> wild_u16_)); - casts.emplace_back("saturating_shift_right_narrow", u8_sat(wild_i16x_ >> wild_u16_)); - casts.emplace_back("saturating_shift_right_narrow", i16_sat(wild_i32x_ >> wild_u32_)); - casts.emplace_back("saturating_shift_right_narrow", u16_sat(wild_u32x_ >> wild_u32_)); - casts.emplace_back("saturating_shift_right_narrow", u16_sat(wild_i32x_ >> wild_u32_)); - casts.emplace_back("saturating_shift_right_narrow", i32_sat(wild_i64x_ >> wild_u64_)); - casts.emplace_back("saturating_shift_right_narrow", u32_sat(wild_u64x_ >> wild_u64_)); - casts.emplace_back("saturating_shift_right_narrow", u32_sat(wild_i64x_ >> wild_u64_)); + calls.emplace_back("saturating_shift_right_narrow", i8_sat(wild_i16x_ >> wild_u16_)); + calls.emplace_back("saturating_shift_right_narrow", u8_sat(wild_u16x_ >> wild_u16_)); + calls.emplace_back("saturating_shift_right_narrow", u8_sat(wild_i16x_ >> wild_u16_)); + calls.emplace_back("saturating_shift_right_narrow", i16_sat(wild_i32x_ >> wild_u32_)); + calls.emplace_back("saturating_shift_right_narrow", u16_sat(wild_u32x_ >> wild_u32_)); + calls.emplace_back("saturating_shift_right_narrow", u16_sat(wild_i32x_ >> wild_u32_)); + calls.emplace_back("saturating_shift_right_narrow", i32_sat(wild_i64x_ >> wild_u64_)); + calls.emplace_back("saturating_shift_right_narrow", u32_sat(wild_u64x_ >> wild_u64_)); + calls.emplace_back("saturating_shift_right_narrow", u32_sat(wild_i64x_ >> wild_u64_)); // SRSHL, URSHL - Rounding shift left (by signed vector) // These are already written as rounding_shift_left @@ -189,15 +190,15 @@ CodeGen_ARM::CodeGen_ARM(const Target &target) // These patterns are almost identity, we just need to strip off the broadcast. // SQXTN, UQXTN, SQXTUN - Saturating narrow. - casts.emplace_back("saturating_narrow", i8_sat(wild_i16x_)); - casts.emplace_back("saturating_narrow", u8_sat(wild_u16x_)); - casts.emplace_back("saturating_narrow", u8_sat(wild_i16x_)); - casts.emplace_back("saturating_narrow", i16_sat(wild_i32x_)); - casts.emplace_back("saturating_narrow", u16_sat(wild_u32x_)); - casts.emplace_back("saturating_narrow", u16_sat(wild_i32x_)); - casts.emplace_back("saturating_narrow", i32_sat(wild_i64x_)); - casts.emplace_back("saturating_narrow", u32_sat(wild_u64x_)); - casts.emplace_back("saturating_narrow", u32_sat(wild_i64x_)); + calls.emplace_back("saturating_narrow", i8_sat(wild_i16x_)); + calls.emplace_back("saturating_narrow", u8_sat(wild_u16x_)); + calls.emplace_back("saturating_narrow", u8_sat(wild_i16x_)); + calls.emplace_back("saturating_narrow", i16_sat(wild_i32x_)); + calls.emplace_back("saturating_narrow", u16_sat(wild_u32x_)); + calls.emplace_back("saturating_narrow", u16_sat(wild_i32x_)); + calls.emplace_back("saturating_narrow", i32_sat(wild_i64x_)); + calls.emplace_back("saturating_narrow", u32_sat(wild_u64x_)); + calls.emplace_back("saturating_narrow", u32_sat(wild_i64x_)); // SQNEG - Saturating negate negations.emplace_back("saturating_negate", -max(wild_i8x_, -127)); @@ -297,14 +298,6 @@ const ArmIntrinsic intrinsic_defs[] = { {"vrhadds", "srhadd", Int(32, 2), "rounding_halving_add", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::HalfWidth}, {"vrhaddu", "urhadd", UInt(32, 2), "rounding_halving_add", {UInt(32, 2), UInt(32, 2)}, ArmIntrinsic::HalfWidth}, - // SRHSUB, URHSUB - Halving sub with rounding - {"vrhsubs", "srhsub", Int(8, 8), "rounding_halving_sub", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::HalfWidth}, - {"vrhsubu", "urhsub", UInt(8, 8), "rounding_halving_sub", {UInt(8, 8), UInt(8, 8)}, ArmIntrinsic::HalfWidth}, - {"vrhsubs", "srhsub", Int(16, 4), "rounding_halving_sub", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::HalfWidth}, - {"vrhsubu", "urhsub", UInt(16, 4), "rounding_halving_sub", {UInt(16, 4), UInt(16, 4)}, ArmIntrinsic::HalfWidth}, - {"vrhsubs", "srhsub", Int(32, 2), "rounding_halving_sub", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::HalfWidth}, - {"vrhsubu", "urhsub", UInt(32, 2), "rounding_halving_sub", {UInt(32, 2), UInt(32, 2)}, ArmIntrinsic::HalfWidth}, - // SMIN, UMIN, FMIN - Min {"vmins", "smin", Int(8, 8), "min", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::HalfWidth}, {"vminu", "umin", UInt(8, 8), "min", {UInt(8, 8), UInt(8, 8)}, ArmIntrinsic::HalfWidth}, @@ -805,38 +798,6 @@ void CodeGen_ARM::visit(const Cast *op) { return; } } - - // If we didn't find a pattern, try rewriting the cast. - static const vector> cast_rewrites = { - // Double or triple narrowing saturating casts are better expressed as - // regular narrowing casts. - {u8_sat(wild_u32x_), u8_sat(u16_sat(wild_u32x_))}, - {u8_sat(wild_i32x_), u8_sat(i16_sat(wild_i32x_))}, - {u8_sat(wild_f32x_), u8_sat(i16_sat(wild_f32x_))}, - {i8_sat(wild_u32x_), i8_sat(u16_sat(wild_u32x_))}, - {i8_sat(wild_i32x_), i8_sat(i16_sat(wild_i32x_))}, - {i8_sat(wild_f32x_), i8_sat(i16_sat(wild_f32x_))}, - {u16_sat(wild_u64x_), u16_sat(u32_sat(wild_u64x_))}, - {u16_sat(wild_i64x_), u16_sat(i32_sat(wild_i64x_))}, - {u16_sat(wild_f64x_), u16_sat(i32_sat(wild_f64x_))}, - {i16_sat(wild_u64x_), i16_sat(u32_sat(wild_u64x_))}, - {i16_sat(wild_i64x_), i16_sat(i32_sat(wild_i64x_))}, - {i16_sat(wild_f64x_), i16_sat(i32_sat(wild_f64x_))}, - {u8_sat(wild_u64x_), u8_sat(u16_sat(u32_sat(wild_u64x_)))}, - {u8_sat(wild_i64x_), u8_sat(i16_sat(i32_sat(wild_i64x_)))}, - {u8_sat(wild_f64x_), u8_sat(i16_sat(i32_sat(wild_f64x_)))}, - {i8_sat(wild_u64x_), i8_sat(u16_sat(u32_sat(wild_u64x_)))}, - {i8_sat(wild_i64x_), i8_sat(i16_sat(i32_sat(wild_i64x_)))}, - {i8_sat(wild_f64x_), i8_sat(i16_sat(i32_sat(wild_f64x_)))}, - }; - for (const auto &i : cast_rewrites) { - if (expr_match(i.first, op, matches)) { - Expr replacement = substitute("*", matches[0], with_lanes(i.second, op->type.lanes())); - debug(3) << "rewriting cast to: " << replacement << " from " << Expr(op) << "\n"; - value = codegen(replacement); - return; - } - } } // LLVM fptoui generates fcvtzs if src is fp16 scalar else fcvtzu. @@ -999,15 +960,21 @@ void CodeGen_ARM::visit(const Store *op) { // Declare the function std::ostringstream instr; vector arg_types; + llvm::Type *intrin_llvm_type = llvm_type_of(intrin_type); +#if LLVM_VERSION >= 150 + const bool is_opaque = llvm::PointerType::get(intrin_llvm_type, 0)->isOpaque(); +#else + const bool is_opaque = false; +#endif if (target.bits == 32) { instr << "llvm.arm.neon.vst" << num_vecs - << ".p0i8" + << (is_opaque ? ".p0" : ".p0i8") << ".v" << intrin_type.lanes() << (t.is_float() ? 'f' : 'i') << t.bits(); - arg_types = vector(num_vecs + 2, llvm_type_of(intrin_type)); + arg_types = vector(num_vecs + 2, intrin_llvm_type); arg_types.front() = i8_t->getPointerTo(); arg_types.back() = i32_t; } else { @@ -1017,10 +984,11 @@ void CodeGen_ARM::visit(const Store *op) { << intrin_type.lanes() << (t.is_float() ? 'f' : 'i') << t.bits() - << ".p0" - << (t.is_float() ? 'f' : 'i') - << t.bits(); - arg_types = vector(num_vecs + 1, llvm_type_of(intrin_type)); + << ".p0"; + if (!is_opaque) { + instr << (t.is_float() ? 'f' : 'i') << t.bits(); + } + arg_types = vector(num_vecs + 1, intrin_llvm_type); arg_types.back() = llvm_type_of(intrin_type.element_of())->getPointerTo(); } llvm::FunctionType *fn_type = FunctionType::get(llvm::Type::getVoidTy(*context), arg_types, false); @@ -1177,12 +1145,55 @@ void CodeGen_ARM::visit(const Call *op) { vector matches; for (const Pattern &pattern : calls) { if (expr_match(pattern.pattern, op, matches)) { + if (pattern.intrin.find("shift_right_narrow") != string::npos) { + // The shift_right_narrow patterns need the shift to be constant in [1, output_bits]. + const uint64_t *const_b = as_const_uint(matches[1]); + if (!const_b || *const_b == 0 || (int)*const_b > op->type.bits()) { + continue; + } + } + if (target.bits == 32 && pattern.intrin.find("shift_right") != string::npos) { + // The 32-bit ARM backend wants right shifts as negative values. + matches[1] = simplify(-cast(matches[1].type().with_code(halide_type_int), matches[1])); + } value = call_overloaded_intrin(op->type, pattern.intrin, matches); if (value) { return; } } } + + // If we didn't find a pattern, try rewriting any saturating casts. + static const vector> cast_rewrites = { + // Double or triple narrowing saturating casts are better expressed as + // combinations of single narrowing saturating casts. + {u8_sat(wild_u32x_), u8_sat(u16_sat(wild_u32x_))}, + {u8_sat(wild_i32x_), u8_sat(i16_sat(wild_i32x_))}, + {u8_sat(wild_f32x_), u8_sat(i16_sat(wild_f32x_))}, + {i8_sat(wild_u32x_), i8_sat(u16_sat(wild_u32x_))}, + {i8_sat(wild_i32x_), i8_sat(i16_sat(wild_i32x_))}, + {i8_sat(wild_f32x_), i8_sat(i16_sat(wild_f32x_))}, + {u16_sat(wild_u64x_), u16_sat(u32_sat(wild_u64x_))}, + {u16_sat(wild_i64x_), u16_sat(i32_sat(wild_i64x_))}, + {u16_sat(wild_f64x_), u16_sat(i32_sat(wild_f64x_))}, + {i16_sat(wild_u64x_), i16_sat(u32_sat(wild_u64x_))}, + {i16_sat(wild_i64x_), i16_sat(i32_sat(wild_i64x_))}, + {i16_sat(wild_f64x_), i16_sat(i32_sat(wild_f64x_))}, + {u8_sat(wild_u64x_), u8_sat(u16_sat(u32_sat(wild_u64x_)))}, + {u8_sat(wild_i64x_), u8_sat(i16_sat(i32_sat(wild_i64x_)))}, + {u8_sat(wild_f64x_), u8_sat(i16_sat(i32_sat(wild_f64x_)))}, + {i8_sat(wild_u64x_), i8_sat(u16_sat(u32_sat(wild_u64x_)))}, + {i8_sat(wild_i64x_), i8_sat(i16_sat(i32_sat(wild_i64x_)))}, + {i8_sat(wild_f64x_), i8_sat(i16_sat(i32_sat(wild_f64x_)))}, + }; + for (const auto &i : cast_rewrites) { + if (expr_match(i.first, op, matches)) { + Expr replacement = substitute("*", matches[0], with_lanes(i.second, op->type.lanes())); + debug(3) << "rewriting cast to: " << replacement << " from " << Expr(op) << "\n"; + value = codegen(replacement); + return; + } + } } if (target.has_feature(Target::ARMFp16)) { @@ -1393,7 +1404,7 @@ Type CodeGen_ARM::upgrade_type_for_storage(const Type &t) const { return CodeGen_Posix::upgrade_type_for_storage(t); } -string CodeGen_ARM::mcpu() const { +string CodeGen_ARM::mcpu_target() const { if (target.bits == 32) { if (target.has_feature(Target::ARMv7s)) { return "swift"; @@ -1411,6 +1422,10 @@ string CodeGen_ARM::mcpu() const { } } +string CodeGen_ARM::mcpu_tune() const { + return mcpu_target(); +} + string CodeGen_ARM::mattrs() const { if (target.bits == 32) { if (target.has_feature(Target::ARMv7s)) { diff --git a/src/CodeGen_C.cpp b/src/CodeGen_C.cpp index 02b888f5b981..c5b64f0610f5 100644 --- a/src/CodeGen_C.cpp +++ b/src/CodeGen_C.cpp @@ -191,11 +191,11 @@ template inline T halide_cpp_min(const T &a, const T &b) {return (a < b) ? a : b;} template -inline void halide_unused(const T&) {} +inline void halide_maybe_unused(const T&) {} template const B &return_second(const A &a, const B &b) { - halide_unused(a); + halide_maybe_unused(a); return b; } @@ -1523,7 +1523,7 @@ void CodeGen_C::emit_argv_wrapper(const std::string &function_name, void CodeGen_C::emit_metadata_getter(const std::string &function_name, const std::vector &args, - const std::map &metadata_name_map) { + const MetadataNameMap &metadata_name_map) { if (is_header_or_extern_decl()) { stream << "\nHALIDE_FUNCTION_ATTRS\nconst struct halide_filter_metadata_t *" << function_name << "_metadata();\n"; return; @@ -1751,6 +1751,7 @@ void CodeGen_C::compile(const Module &input) { stream << "\n"; if (!is_header_or_extern_decl()) { +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE // Emit any external-code blobs that are C++. for (const ExternalCode &code_blob : input.external_code()) { if (code_blob.is_c_plus_plus_source()) { @@ -1762,6 +1763,7 @@ void CodeGen_C::compile(const Module &input) { stream << "\n"; } } +#endif add_vector_typedefs(type_info.vector_types_used); @@ -1798,7 +1800,7 @@ void CodeGen_C::compile(const Module &input) { } } -void CodeGen_C::compile(const LoweredFunc &f, const std::map &metadata_name_map) { +void CodeGen_C::compile(const LoweredFunc &f, const MetadataNameMap &metadata_name_map) { // Don't put non-external function declarations in headers. if (is_header_or_extern_decl() && f.linkage == LinkageType::Internal) { return; @@ -1873,9 +1875,9 @@ void CodeGen_C::compile(const LoweredFunc &f, const std::map(__user_context)" : "nullptr") << ";\n"; - if (target.has_feature(Target::NoAsserts)) { - stream << get_indent() << "halide_unused(_ucon);"; - } + // Always declare it unused, since this could be a generated closure that doesn't + // use _ucon at all, regardless of NoAsserts. + stream << get_indent() << "halide_maybe_unused(_ucon);\n"; // Emit the body print(f.body); @@ -2069,6 +2071,10 @@ void CodeGen_C::visit(const Cast *op) { id = print_cast_expr(op->type, op->value); } +void CodeGen_C::visit(const Reinterpret *op) { + id = print_assignment(op->type, print_reinterpret(op->type, op->value)); +} + void CodeGen_C::visit_binop(Type t, const Expr &a, const Expr &b, const char *op) { string sa = print_expr(a); string sb = print_expr(b); @@ -2292,9 +2298,6 @@ void CodeGen_C::visit(const Call *op) { } else if (op->is_intrinsic(Call::bitwise_not)) { internal_assert(op->args.size() == 1); rhs << "~" << print_expr(op->args[0]); - } else if (op->is_intrinsic(Call::reinterpret)) { - internal_assert(op->args.size() == 1); - rhs << print_reinterpret(op->type, op->args[0]); } else if (op->is_intrinsic(Call::shift_left)) { internal_assert(op->args.size() == 2); if (op->args[1].type().is_uint()) { @@ -2555,6 +2558,8 @@ void CodeGen_C::visit(const Call *op) { user_error << "Signed integer overflow occurred during constant-folding. Signed" " integer overflow for int32 and int64 is undefined behavior in" " Halide.\n"; + } else if (op->is_intrinsic(Call::undef)) { + user_error << "undef not eliminated before code generation. Please report this as a Halide bug.\n"; } else if (op->is_intrinsic(Call::prefetch)) { user_assert((op->args.size() == 4) && is_const_one(op->args[2])) << "Only prefetch of 1 cache line is supported in C backend.\n"; @@ -2750,7 +2755,7 @@ void CodeGen_C::visit(const Let *op) { std::string name = print_name(op->name); stream << get_indent() << "auto " << name << " = " << id_value << ";\n"; - stream << get_indent() << "halide_unused(" << name << ");\n"; + stream << get_indent() << "halide_maybe_unused(" << name << ");\n"; } else { Expr new_var = Variable::make(op->value.type(), id_value); body = substitute(op->name, new_var, body); @@ -2845,7 +2850,7 @@ void CodeGen_C::visit(const LetStmt *op) { std::string name = print_name(op->name); stream << get_indent() << "auto " << name << " = " << id_value << ";\n"; - stream << get_indent() << "halide_unused(" << name << ");\n"; + stream << get_indent() << "halide_maybe_unused(" << name << ");\n"; } else { Expr new_var = Variable::make(op->value.type(), id_value); body = substitute(op->name, new_var, body); @@ -2862,7 +2867,7 @@ void CodeGen_C::create_assertion(const string &id_cond, const Expr &message) { << "Assertion result is not an int: " << message; if (target.has_feature(Target::NoAsserts)) { - stream << get_indent() << "halide_unused(" << id_cond << ");\n"; + stream << get_indent() << "halide_maybe_unused(" << id_cond << ");\n"; return; } @@ -3152,7 +3157,7 @@ void CodeGen_C::visit(const Evaluate *op) { return; } string id = print_expr(op->value); - stream << get_indent() << "halide_unused(" << id << ");\n"; + stream << get_indent() << "halide_maybe_unused(" << id << ");\n"; } void CodeGen_C::visit(const Shuffle *op) { @@ -3261,9 +3266,10 @@ extern "C" { HALIDE_FUNCTION_ATTRS int test1(struct halide_buffer_t *_buf_buffer, float _alpha, int32_t _beta, void const *__user_context) { void * const _ucon = const_cast(__user_context); + halide_maybe_unused(_ucon); auto *_0 = _halide_buffer_get_host(_buf_buffer); auto _buf = _0; - halide_unused(_buf); + halide_maybe_unused(_buf); { int64_t _1 = 43; int64_t _2 = _1 * _beta; diff --git a/src/CodeGen_C.h b/src/CodeGen_C.h index 7e3544b36d79..9c06d4bb5630 100644 --- a/src/CodeGen_C.h +++ b/src/CodeGen_C.h @@ -65,7 +65,7 @@ class CodeGen_C : public IRPrinter { /** Emit a declaration. */ // @{ - virtual void compile(const LoweredFunc &func, const std::map &metadata_name_map); + virtual void compile(const LoweredFunc &func, const MetadataNameMap &metadata_name_map); virtual void compile(const Buffer<> &buffer); // @} @@ -196,6 +196,7 @@ class CodeGen_C : public IRPrinter { void visit(const StringImm *) override; void visit(const FloatImm *) override; void visit(const Cast *) override; + void visit(const Reinterpret *) override; void visit(const Add *) override; void visit(const Sub *) override; void visit(const Mul *) override; @@ -270,7 +271,7 @@ class CodeGen_C : public IRPrinter { const std::vector &args); void emit_metadata_getter(const std::string &function_name, const std::vector &args, - const std::map &metadata_name_map); + const MetadataNameMap &metadata_name_map); }; } // namespace Internal diff --git a/src/CodeGen_D3D12Compute_Dev.cpp b/src/CodeGen_D3D12Compute_Dev.cpp index b1e99626e53b..7d52c661df04 100644 --- a/src/CodeGen_D3D12Compute_Dev.cpp +++ b/src/CodeGen_D3D12Compute_Dev.cpp @@ -3,7 +3,6 @@ #include #include -#include "CodeGen_C.h" #include "CodeGen_D3D12Compute_Dev.h" #include "CodeGen_GPU_Dev.h" #include "CodeGen_Internal.h" @@ -62,10 +61,10 @@ class CodeGen_D3D12Compute_Dev : public CodeGen_GPU_Dev { protected: friend struct StoragePackUnpack; - class CodeGen_D3D12Compute_C : public CodeGen_C { + class CodeGen_D3D12Compute_C : public CodeGen_GPU_C { public: CodeGen_D3D12Compute_C(std::ostream &s, const Target &t) - : CodeGen_C(s, t) { + : CodeGen_GPU_C(s, t) { integer_suffix_style = IntegerSuffixStyle::HLSL; } void add_kernel(Stmt stmt, @@ -88,7 +87,7 @@ class CodeGen_D3D12Compute_Dev : public CodeGen_GPU_Dev { std::string print_assignment(Type t, const std::string &rhs) override; - using CodeGen_C::visit; + using CodeGen_GPU_C::visit; void visit(const Evaluate *op) override; void visit(const Min *) override; void visit(const Max *) override; @@ -303,7 +302,7 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const For *loop) { if (!is_gpu_var(loop->name)) { user_assert(loop->for_type != ForType::Parallel) << "Cannot use parallel loops inside D3D12Compute kernel\n"; - CodeGen_C::visit(loop); + CodeGen_GPU_C::visit(loop); return; } @@ -380,7 +379,7 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Call *op) { // directly. stream << "pow(" << print_expr(op->args[0]) << ", " << print_expr(op->args[1]) << ")"; } else { - CodeGen_C::visit(op); + CodeGen_GPU_C::visit(op); } } @@ -815,7 +814,7 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Free *op) { string CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::print_assignment(Type type, const string &rhs) { string rhs_modified = print_reinforced_cast(type, rhs); - return CodeGen_C::print_assignment(type, rhs_modified); + return CodeGen_GPU_C::print_assignment(type, rhs_modified); } string CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::print_vanilla_cast(Type type, const string &value_expr) { @@ -964,7 +963,7 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const FloatImm *op) // have seen division-by-zero shader warnings, and we postulated that it // could be indirectly related to compiler assumptions on signed integer // overflow when float_from_bits() is called, but we don't know for sure - return CodeGen_C::visit(op); + return CodeGen_GPU_C::visit(op); } void CodeGen_D3D12Compute_Dev::add_kernel(Stmt s, @@ -1272,7 +1271,7 @@ void CodeGen_D3D12Compute_Dev::init_module() { "\n" << "\n"; - src_stream << "#define halide_unused(x) (void)(x)\n"; + src_stream << "#define halide_maybe_unused(x) (void)(x)\n"; // Write out the Halide math functions. src_stream diff --git a/src/CodeGen_GPU_Dev.cpp b/src/CodeGen_GPU_Dev.cpp index 1f29f01fc1ad..3606ae41e3c9 100644 --- a/src/CodeGen_GPU_Dev.cpp +++ b/src/CodeGen_GPU_Dev.cpp @@ -116,8 +116,7 @@ class ScalarizePredicatedLoadStore : public IRMutator { mutate(extract_lane(s->index, ln)), s->param, const_true(), - // TODO: alignment needs to be changed - s->alignment))); + s->alignment + ln))); } return Block::make(scalar_stmts); } else { @@ -127,12 +126,23 @@ class ScalarizePredicatedLoadStore : public IRMutator { Expr visit(const Load *op) override { if (!is_const_one(op->predicate)) { - Expr load_expr = Load::make(op->type, op->name, op->index, op->image, - op->param, const_true(op->type.lanes()), op->alignment); - Expr pred_load = Call::make(load_expr.type(), - Call::if_then_else, - {op->predicate, load_expr}, - Internal::Call::PureIntrinsic); + std::vector lane_values; + for (int ln = 0; ln < op->type.lanes(); ln++) { + Expr load_expr = Load::make(op->type.element_of(), + op->name, + extract_lane(op->index, ln), + op->image, + op->param, + const_true(), + op->alignment + ln); + lane_values.push_back(Call::make(load_expr.type(), + Call::if_then_else, + {extract_lane(op->predicate, ln), + load_expr, + make_zero(op->type.element_of())}, + Internal::Call::PureIntrinsic)); + } + Expr pred_load = Shuffle::make_concat(lane_values); return pred_load; } else { return op; @@ -147,5 +157,47 @@ Stmt CodeGen_GPU_Dev::scalarize_predicated_loads_stores(Stmt &s) { return sps.mutate(s); } +void CodeGen_GPU_C::visit(const Shuffle *op) { + if (op->type.is_scalar()) { + CodeGen_C::visit(op); + } else { + internal_assert(!op->vectors.empty()); + for (size_t i = 1; i < op->vectors.size(); i++) { + internal_assert(op->vectors[0].type() == op->vectors[i].type()); + } + internal_assert(op->type.lanes() == (int)op->indices.size()); + const int max_index = (int)(op->vectors[0].type().lanes() * op->vectors.size()); + for (int i : op->indices) { + internal_assert(i >= 0 && i < max_index); + } + + std::vector vecs; + for (const Expr &v : op->vectors) { + vecs.push_back(print_expr(v)); + } + + std::string src = vecs[0]; + std::ostringstream rhs; + std::string storage_name = unique_name('_'); + if (vector_declaration_style == VectorDeclarationStyle::OpenCLSyntax) { + rhs << "(" << print_type(op->type) << ")("; + } else { + rhs << "{"; + } + for (int i : op->indices) { + rhs << vecs[i]; + if (i < (int)(op->indices.size() - 1)) { + rhs << ", "; + } + } + if (vector_declaration_style == VectorDeclarationStyle::OpenCLSyntax) { + rhs << ")"; + } else { + rhs << "}"; + } + print_assignment(op->type, rhs.str()); + } +} + } // namespace Internal } // namespace Halide diff --git a/src/CodeGen_GPU_Dev.h b/src/CodeGen_GPU_Dev.h index 2b516fd62d13..dfbd1f58c49b 100644 --- a/src/CodeGen_GPU_Dev.h +++ b/src/CodeGen_GPU_Dev.h @@ -7,6 +7,7 @@ #include #include +#include "CodeGen_C.h" #include "DeviceArgument.h" #include "Expr.h" @@ -73,8 +74,8 @@ struct CodeGen_GPU_Dev { static Stmt scalarize_predicated_loads_stores(Stmt &s); /** An mask describing which type of memory fence to use for the gpu_thread_barrier() - * intrinsic. Not all GPUs APIs support all types. - */ + * intrinsic. Not all GPUs APIs support all types. + */ enum MemoryFenceType { None = 0, // No fence required (just a sync) Device = 1, // Device/global memory fence @@ -82,6 +83,28 @@ struct CodeGen_GPU_Dev { }; }; +/** A base class for GPU backends that require C-like shader output. + * GPU backends derive from and specialize this class. */ +class CodeGen_GPU_C : public CodeGen_C { +public: + /** OpenCL uses a different syntax than C for immediate vectors. This + enum defines which style should be used by the backend. */ + enum class VectorDeclarationStyle { + CLikeSyntax = 0, + OpenCLSyntax = 1 + }; + + CodeGen_GPU_C(std::ostream &s, Target t) + : CodeGen_C(s, t) { + } + +protected: + using CodeGen_C::visit; + void visit(const Shuffle *op) override; + + VectorDeclarationStyle vector_declaration_style = VectorDeclarationStyle::CLikeSyntax; +}; + } // namespace Internal } // namespace Halide diff --git a/src/CodeGen_Hexagon.cpp b/src/CodeGen_Hexagon.cpp index a32bca98ff7d..a0fd9a850612 100644 --- a/src/CodeGen_Hexagon.cpp +++ b/src/CodeGen_Hexagon.cpp @@ -42,7 +42,8 @@ class CodeGen_Hexagon : public CodeGen_Posix { void init_module() override; - std::string mcpu() const override; + std::string mcpu_target() const override; + std::string mcpu_tune() const override; std::string mattrs() const override; int isa_version; bool use_soft_float_abi() const override; @@ -126,6 +127,9 @@ class CodeGen_Hexagon : public CodeGen_Posix { /** Generate a LUT (8/16 bit, max_index < 256) lookup using vlut instructions. */ llvm::Value *vlut256(llvm::Value *lut, llvm::Value *indices, int min_index = 0, int max_index = 255); + + /** Wrapper to create a vector populated with a constant value in each lane. */ + Value *create_vector(llvm::Type *ty, int val); }; CodeGen_Hexagon::CodeGen_Hexagon(const Target &t) @@ -994,8 +998,8 @@ llvm::Function *CodeGen_Hexagon::define_hvx_intrinsic(llvm::Function *intrin, Value *CodeGen_Hexagon::create_bitcast(Value *v, llvm::Type *ty) { if (BitCastInst *c = dyn_cast(v)) { return create_bitcast(c->getOperand(0), ty); - } else if (isa(v)) { - return UndefValue::get(ty); + } else if (isa(v)) { + return PoisonValue::get(ty); } else if (v->getType() != ty) { v = builder->CreateBitCast(v, ty); } @@ -1174,7 +1178,7 @@ Value *CodeGen_Hexagon::shuffle_vectors(Value *a, Value *b, i -= a_elements; } } - return shuffle_vectors(b, UndefValue::get(b->getType()), shifted_indices); + return shuffle_vectors(b, shifted_indices); } // Try to rewrite shuffles that only access the elements of a. @@ -1613,7 +1617,7 @@ Value *CodeGen_Hexagon::vdelta(Value *lut, const vector &indices) { return vlut(lut, indices); } -Value *create_vector(llvm::Type *ty, int val) { +Value *CodeGen_Hexagon::create_vector(llvm::Type *ty, int val) { llvm::Type *scalar_ty = ty->getScalarType(); Constant *value = ConstantInt::get(scalar_ty, val); return ConstantVector::getSplat(element_count(get_vector_num_elements(ty)), value); @@ -1788,7 +1792,7 @@ Value *CodeGen_Hexagon::call_intrin(llvm::Type *result_type, const string &name, fn, std::move(args)); } -string CodeGen_Hexagon::mcpu() const { +string CodeGen_Hexagon::mcpu_target() const { if (target.has_feature(Halide::Target::HVX_v66)) { return "hexagonv66"; } else if (target.has_feature(Halide::Target::HVX_v65)) { @@ -1798,6 +1802,10 @@ string CodeGen_Hexagon::mcpu() const { } } +string CodeGen_Hexagon::mcpu_tune() const { + return mcpu_target(); +} + string CodeGen_Hexagon::mattrs() const { std::stringstream attrs; attrs << "+hvx-length128b"; diff --git a/src/CodeGen_Internal.cpp b/src/CodeGen_Internal.cpp index f7f7b2917991..61ddd8c3cbe9 100644 --- a/src/CodeGen_Internal.cpp +++ b/src/CodeGen_Internal.cpp @@ -16,7 +16,8 @@ using std::string; using namespace llvm; -llvm::Type *llvm_type_of(LLVMContext *c, Halide::Type t) { +llvm::Type *llvm_type_of(LLVMContext *c, Halide::Type t, + int effective_vscale) { if (t.lanes() == 1) { if (t.is_float() && !t.is_bfloat()) { switch (t.bits()) { @@ -36,18 +37,27 @@ llvm::Type *llvm_type_of(LLVMContext *c, Halide::Type t) { return llvm::Type::getIntNTy(*c, t.bits()); } } else { - llvm::Type *element_type = llvm_type_of(c, t.element_of()); - return get_vector_type(element_type, t.lanes()); - } -} - -int get_vector_num_elements(llvm::Type *t) { - if (t->isVectorTy()) { - auto *vt = dyn_cast(t); - internal_assert(vt) << "Called get_vector_num_elements on a scalable vector type\n"; - return vt->getNumElements(); - } else { - return 1; + llvm::Type *element_type = llvm_type_of(c, t.element_of(), 0); + bool scalable = false; + int lanes = t.lanes(); + if (effective_vscale != 0) { + int total_bits = t.bits() * t.lanes(); + scalable = ((total_bits % effective_vscale) == 0); + if (scalable) { + lanes /= effective_vscale; + } else { + // TODO(zvookin): This error indicates that the requested number of vector lanes + // is not expressible exactly via vscale. This will be fairly unusual unless + // non-power of two, or very short, vector sizes are used in a schedule. + // It is made an error, instead of passing the fixed non-vscale vector type to LLVM, + // to catch the case early while developing vscale backends. + // We may need to change this to allow the case so if one hits this error in situation + // where it should pass through a fixed width vector type, please discuss. + internal_error << "Failed to make vscale vector type with bits " << t.bits() << " lanes " << t.lanes() + << " effective_vscale " << effective_vscale << " total_bits " << total_bits << "\n"; + } + } + return get_vector_type(element_type, lanes, scalable); } } @@ -63,8 +73,8 @@ llvm::ElementCount element_count(int e) { return llvm::ElementCount::getFixed(e); } -llvm::Type *get_vector_type(llvm::Type *t, int n) { - return VectorType::get(t, element_count(n)); +llvm::Type *get_vector_type(llvm::Type *t, int n, bool scalable) { + return VectorType::get(t, n, scalable); } // Returns true if the given function name is one of the Halide runtime @@ -590,16 +600,15 @@ bool get_md_string(llvm::Metadata *value, std::string &result) { return false; } -void get_target_options(const llvm::Module &module, llvm::TargetOptions &options, std::string &mcpu, std::string &mattrs) { +void get_target_options(const llvm::Module &module, llvm::TargetOptions &options) { bool use_soft_float_abi = false; get_md_bool(module.getModuleFlag("halide_use_soft_float_abi"), use_soft_float_abi); - get_md_string(module.getModuleFlag("halide_mcpu"), mcpu); - get_md_string(module.getModuleFlag("halide_mattrs"), mattrs); std::string mabi; get_md_string(module.getModuleFlag("halide_mabi"), mabi); bool use_pic = true; get_md_bool(module.getModuleFlag("halide_use_pic"), use_pic); + // FIXME: can this be migrated into `set_function_attributes_from_halide_target_options()`? bool per_instruction_fast_math_flags = false; get_md_bool(module.getModuleFlag("halide_per_instruction_fast_math_flags"), per_instruction_fast_math_flags); @@ -611,11 +620,6 @@ void get_target_options(const llvm::Module &module, llvm::TargetOptions &options options.HonorSignDependentRoundingFPMathOption = !per_instruction_fast_math_flags; options.NoZerosInBSS = false; options.GuaranteedTailCallOpt = false; -#if LLVM_VERSION >= 130 - // nothing -#else - options.StackAlignmentOverride = 0; -#endif options.FunctionSections = true; options.UseInitArray = true; options.FloatABIType = @@ -634,9 +638,14 @@ void clone_target_options(const llvm::Module &from, llvm::Module &to) { to.addModuleFlag(llvm::Module::Warning, "halide_use_soft_float_abi", use_soft_float_abi ? 1 : 0); } - std::string mcpu; - if (get_md_string(from.getModuleFlag("halide_mcpu"), mcpu)) { - to.addModuleFlag(llvm::Module::Warning, "halide_mcpu", llvm::MDString::get(context, mcpu)); + std::string mcpu_target; + if (get_md_string(from.getModuleFlag("halide_mcpu_target"), mcpu_target)) { + to.addModuleFlag(llvm::Module::Warning, "halide_mcpu_target", llvm::MDString::get(context, mcpu_target)); + } + + std::string mcpu_tune; + if (get_md_string(from.getModuleFlag("halide_mcpu_tune"), mcpu_tune)) { + to.addModuleFlag(llvm::Module::Warning, "halide_mcpu_tune", llvm::MDString::get(context, mcpu_tune)); } std::string mattrs; @@ -662,9 +671,7 @@ std::unique_ptr make_target_machine(const llvm::Module &mod internal_assert(llvm_target) << "Could not create LLVM target for " << triple.str() << "\n"; llvm::TargetOptions options; - std::string mcpu = ""; - std::string mattrs = ""; - get_target_options(module, options, mcpu, mattrs); + get_target_options(module, options); bool use_pic = true; get_md_bool(module.getModuleFlag("halide_use_pic"), use_pic); @@ -673,7 +680,7 @@ std::unique_ptr make_target_machine(const llvm::Module &mod get_md_bool(module.getModuleFlag("halide_use_large_code_model"), use_large_code_model); auto *tm = llvm_target->createTargetMachine(module.getTargetTriple(), - mcpu, mattrs, + /*CPU target=*/"", /*Features=*/"", options, use_pic ? llvm::Reloc::PIC_ : llvm::Reloc::Static, use_large_code_model ? llvm::CodeModel::Large : llvm::CodeModel::Small, @@ -681,20 +688,43 @@ std::unique_ptr make_target_machine(const llvm::Module &mod return std::unique_ptr(tm); } -void set_function_attributes_for_target(llvm::Function *fn, const Target &t) { +void set_function_attributes_from_halide_target_options(llvm::Function &fn) { + llvm::Module &module = *fn.getParent(); + + std::string mcpu_target, mcpu_tune, mattrs, vscale_range; + get_md_string(module.getModuleFlag("halide_mcpu_target"), mcpu_target); + get_md_string(module.getModuleFlag("halide_mcpu_tune"), mcpu_tune); + get_md_string(module.getModuleFlag("halide_mattrs"), mattrs); + get_md_string(module.getModuleFlag("halide_vscale_range"), vscale_range); + + fn.addFnAttr("target-cpu", mcpu_target); + fn.addFnAttr("tune-cpu", mcpu_tune); + fn.addFnAttr("target-features", mattrs); + + // Halide-generated IR is not exception-safe. + // No exception should unwind out of Halide functions. + // No exception should be thrown within Halide functions. + // All functions called by the Halide function must not unwind. + fn.setDoesNotThrow(); + + // Side-effect-free loops are undefined. + // But asserts and external calls *might* abort. + fn.setMustProgress(); + // Turn off approximate reciprocals for division. It's too // inaccurate even for us. - fn->addFnAttr("reciprocal-estimates", "none"); + fn.addFnAttr("reciprocal-estimates", "none"); + + // If a fixed vscale is asserted, add it as an attribute on the function. + if (!vscale_range.empty()) { + fn.addFnAttr("vscale_range", vscale_range); + } } void embed_bitcode(llvm::Module *M, const string &halide_command) { // Save llvm.compiler.used and remote it. SmallVector used_array; -#if LLVM_VERSION >= 130 SmallVector used_globals; -#else - SmallPtrSet used_globals; -#endif llvm::Type *used_element_type = llvm::Type::getInt8Ty(M->getContext())->getPointerTo(0); GlobalVariable *used = collectUsedGlobalVariables(*M, used_globals, true); for (auto *GV : used_globals) { @@ -765,5 +795,34 @@ void embed_bitcode(llvm::Module *M, const string &halide_command) { } } +Expr lower_concat_bits(const Call *op) { + internal_assert(op->is_intrinsic(Call::concat_bits)); + internal_assert(!op->args.empty()); + + Expr result = make_zero(op->type); + int shift = 0; + for (const Expr &e : op->args) { + result = result | (cast(result.type(), e) << shift); + shift += e.type().bits(); + } + return result; +} + +Expr lower_extract_bits(const Call *op) { + Expr e = op->args[0]; + // Do a shift-and-cast as a uint, which will zero-fill any out-of-range + // bits for us. + if (!e.type().is_uint()) { + e = reinterpret(e.type().with_code(halide_type_uint), e); + } + e = e >> op->args[1]; + e = cast(op->type.with_code(halide_type_uint), e); + if (op->type != e.type()) { + e = reinterpret(op->type, e); + } + e = simplify(e); + return e; +} + } // namespace Internal } // namespace Halide diff --git a/src/CodeGen_Internal.h b/src/CodeGen_Internal.h index 3fe1b8b696f5..8c1a0e1994eb 100644 --- a/src/CodeGen_Internal.h +++ b/src/CodeGen_Internal.h @@ -37,12 +37,14 @@ struct Target; namespace Internal { -/** Get the llvm type equivalent to a given halide type */ -llvm::Type *llvm_type_of(llvm::LLVMContext *context, Halide::Type t); - -/** Get the number of elements in an llvm vector type, or return 1 if - * it's not a vector type. */ -int get_vector_num_elements(llvm::Type *); +/** Get the llvm type equivalent to a given halide type. If + * effective_vscale is nonzero and the type is a vector type with lanes + * a multiple of effective_vscale, a scalable vector type is generated + * with total lanes divided by effective_vscale. That is a scalable + * vector intended to be used with a fixed vscale of effective_vscale. + */ +llvm::Type *llvm_type_of(llvm::LLVMContext *context, Halide::Type t, + int effective_vscale); /** Get the scalar type of an llvm vector type. Returns the argument * if it's not a vector type. */ @@ -50,7 +52,7 @@ llvm::Type *get_vector_element_type(llvm::Type *); llvm::ElementCount element_count(int e); -llvm::Type *get_vector_type(llvm::Type *, int); +llvm::Type *get_vector_type(llvm::Type *, int n, bool scalable = false); /** Which built-in functions require a user-context first argument? */ bool function_takes_user_context(const std::string &name); @@ -92,8 +94,14 @@ Expr lower_signed_shift_right(const Expr &a, const Expr &b); /** Reduce a mux intrinsic to a select tree */ Expr lower_mux(const Call *mux); -/** Given an llvm::Module, set llvm:TargetOptions, cpu and attr information */ -void get_target_options(const llvm::Module &module, llvm::TargetOptions &options, std::string &mcpu, std::string &mattrs); +/** Reduce bit extraction and concatenation to bit ops */ +///@{ +Expr lower_extract_bits(const Call *c); +Expr lower_concat_bits(const Call *c); +///@} + +/** Given an llvm::Module, set llvm:TargetOptions information */ +void get_target_options(const llvm::Module &module, llvm::TargetOptions &options); /** Given two llvm::Modules, clone target options from one to the other */ void clone_target_options(const llvm::Module &from, llvm::Module &to); @@ -101,8 +109,8 @@ void clone_target_options(const llvm::Module &from, llvm::Module &to); /** Given an llvm::Module, get or create an llvm:TargetMachine */ std::unique_ptr make_target_machine(const llvm::Module &module); -/** Set the appropriate llvm Function attributes given a Target. */ -void set_function_attributes_for_target(llvm::Function *, const Target &); +/** Set the appropriate llvm Function attributes given the Halide Target. */ +void set_function_attributes_from_halide_target_options(llvm::Function &); /** Save a copy of the llvm IR currently represented by the module as * data in the __LLVM,__bitcode section. Emulates clang's diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 22945cc4361b..9e6138b6e4a1 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -23,7 +23,6 @@ #include "LLVM_Runtime_Linker.h" #include "Lerp.h" #include "LowerParallelTasks.h" -#include "MatlabWrapper.h" #include "Pipeline.h" #include "Simplify.h" #include "Util.h" @@ -147,21 +146,12 @@ namespace { llvm::Value *CreateConstGEP1_32(IRBuilderBase *builder, llvm::Type *gep_type, Value *ptr, unsigned index) { -#if LLVM_VERSION >= 130 return builder->CreateConstGEP1_32(gep_type, ptr, index); -#else - (void)gep_type; - return builder->CreateConstGEP1_32(ptr, index); -#endif } llvm::Value *CreateInBoundsGEP(IRBuilderBase *builder, llvm::Type *gep_type, Value *ptr, ArrayRef index_list) { -#if LLVM_VERSION >= 130 return builder->CreateInBoundsGEP(gep_type, ptr, index_list); -#else - return builder->CreateInBoundsGEP(ptr, index_list); -#endif } // Get the LLVM linkage corresponding to a Halide linkage type. @@ -230,12 +220,14 @@ CodeGen_LLVM::CodeGen_LLVM(const Target &t) destructor_block(nullptr), strict_float(t.has_feature(Target::StrictFloat)), - llvm_large_code_model(t.has_feature(Target::LLVMLargeCodeModel)) { + llvm_large_code_model(t.has_feature(Target::LLVMLargeCodeModel)), + effective_vscale(0) { initialize_llvm(); } void CodeGen_LLVM::set_context(llvm::LLVMContext &context) { this->context = &context; + effective_vscale = target_vscale(); } std::unique_ptr CodeGen_LLVM::new_for_target(const Target &target, llvm::LLVMContext &context) { @@ -296,7 +288,6 @@ void CodeGen_LLVM::initialize_llvm() { #define LLVM_ASM_PRINTER(target) \ Initialize##target##AsmPrinter(); #include -#include #undef LLVM_ASM_PRINTER }); } @@ -347,6 +338,7 @@ void CodeGen_LLVM::init_module() { module = get_initial_module_for_target(target, context); } +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE void CodeGen_LLVM::add_external_code(const Module &halide_module) { for (const ExternalCode &code_blob : halide_module.external_code()) { if (code_blob.is_for_cpu_target(get_target())) { @@ -354,6 +346,7 @@ void CodeGen_LLVM::add_external_code(const Module &halide_module) { } } } +#endif CodeGen_LLVM::~CodeGen_LLVM() { delete builder; @@ -465,12 +458,17 @@ void CodeGen_LLVM::init_codegen(const std::string &name, bool any_strict_float) // Add some target specific info to the module as metadata. module->addModuleFlag(llvm::Module::Warning, "halide_use_soft_float_abi", use_soft_float_abi() ? 1 : 0); - module->addModuleFlag(llvm::Module::Warning, "halide_mcpu", MDString::get(*context, mcpu())); + module->addModuleFlag(llvm::Module::Warning, "halide_mcpu_target", MDString::get(*context, mcpu_target())); + module->addModuleFlag(llvm::Module::Warning, "halide_mcpu_tune", MDString::get(*context, mcpu_tune())); module->addModuleFlag(llvm::Module::Warning, "halide_mattrs", MDString::get(*context, mattrs())); module->addModuleFlag(llvm::Module::Warning, "halide_mabi", MDString::get(*context, mabi())); module->addModuleFlag(llvm::Module::Warning, "halide_use_pic", use_pic() ? 1 : 0); module->addModuleFlag(llvm::Module::Warning, "halide_use_large_code_model", llvm_large_code_model ? 1 : 0); module->addModuleFlag(llvm::Module::Warning, "halide_per_instruction_fast_math_flags", any_strict_float); + if (effective_vscale != 0) { + module->addModuleFlag(llvm::Module::Warning, "halide_vscale_range", + MDString::get(*context, std::to_string(effective_vscale) + ", " + std::to_string(effective_vscale))); + } // Ensure some types we need are defined halide_buffer_t_type = get_llvm_struct_type_by_name(module.get(), "struct.halide_buffer_t"); @@ -499,12 +497,6 @@ void CodeGen_LLVM::init_codegen(const std::string &name, bool any_strict_float) semaphore_t_type = get_llvm_struct_type_by_name(module.get(), "struct.halide_semaphore_t"); internal_assert(semaphore_t_type) << "Did not find halide_semaphore_t in initial module"; - - semaphore_acquire_t_type = get_llvm_struct_type_by_name(module.get(), "struct.halide_semaphore_acquire_t"); - internal_assert(semaphore_acquire_t_type) << "Did not find halide_semaphore_acquire_t in initial module"; - - parallel_task_t_type = get_llvm_struct_type_by_name(module.get(), "struct.halide_parallel_task_t"); - internal_assert(parallel_task_t_type) << "Did not find halide_parallel_task_t in initial module"; } std::unique_ptr CodeGen_LLVM::compile(const Module &input) { @@ -513,7 +505,9 @@ std::unique_ptr CodeGen_LLVM::compile(const Module &input) { internal_assert(module && context && builder) << "The CodeGen_LLVM subclass should have made an initial module before calling CodeGen_LLVM::compile\n"; +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE add_external_code(input); +#endif // Generate the code for this module. debug(1) << "Generating llvm bitcode...\n"; @@ -539,7 +533,7 @@ std::unique_ptr CodeGen_LLVM::compile(const Module &input) { } FunctionType *func_t = FunctionType::get(i32_t, arg_types, false); function = llvm::Function::Create(func_t, llvm_linkage(f.linkage), names.extern_name, module.get()); - set_function_attributes_for_target(function, target); + set_function_attributes_from_halide_target_options(*function); // Mark the buffer args as no alias and save indication for add_argv_wrapper if needed std::vector buffer_args(f.args.size()); @@ -557,14 +551,10 @@ std::unique_ptr CodeGen_LLVM::compile(const Module &input) { // If the Func is externally visible, also create the argv wrapper and metadata. // (useful for calling from JIT and other machine interfaces). if (f.linkage == LinkageType::ExternalPlusArgv || f.linkage == LinkageType::ExternalPlusMetadata) { - llvm::Function *wrapper = add_argv_wrapper(function, names.argv_name, false, buffer_args); + add_argv_wrapper(function, names.argv_name, false, buffer_args); if (f.linkage == LinkageType::ExternalPlusMetadata) { - llvm::Function *metadata_getter = embed_metadata_getter(names.metadata_name, - names.simple_name, f.args, input.get_metadata_name_map()); - - if (target.has_feature(Target::Matlab)) { - define_matlab_wrapper(module.get(), wrapper, metadata_getter); - } + embed_metadata_getter(names.metadata_name, + names.simple_name, f.args, input.get_metadata_name_map()); } } } @@ -584,6 +574,8 @@ std::unique_ptr CodeGen_LLVM::compile(const Module &input) { } std::unique_ptr CodeGen_LLVM::finish_codegen() { + llvm::for_each(*module, set_function_attributes_from_halide_target_options); + // Verify the module is ok internal_assert(!verifyModule(*module, &llvm::errs())); debug(2) << "Done generating llvm bitcode\n"; @@ -982,7 +974,7 @@ llvm::Function *CodeGen_LLVM::add_argv_wrapper(llvm::Function *fn, llvm::Function *CodeGen_LLVM::embed_metadata_getter(const std::string &metadata_name, const std::string &function_name, const std::vector &args, - const std::map &metadata_name_map) { + const MetadataNameMap &metadata_name_map) { Constant *zero = ConstantInt::get(i32_t, 0); const int num_args = (int)args.size(); @@ -1091,7 +1083,7 @@ llvm::Function *CodeGen_LLVM::embed_metadata_getter(const std::string &metadata_ } llvm::Type *CodeGen_LLVM::llvm_type_of(const Type &t) const { - return Internal::llvm_type_of(context, t); + return Internal::llvm_type_of(context, t, effective_vscale); } void CodeGen_LLVM::optimize_module() { @@ -1105,14 +1097,15 @@ void CodeGen_LLVM::optimize_module() { std::unique_ptr tm = make_target_machine(*module); - // At present, we default to *enabling* LLVM loop optimization, - // unless DisableLLVMLoopOpt is set; we're going to flip this to defaulting - // to *not* enabling these optimizations (and removing the DisableLLVMLoopOpt feature). - // See https://github.com/halide/Halide/issues/4113 for more info. - // (Note that setting EnableLLVMLoopOpt always enables loop opt, regardless - // of the setting of DisableLLVMLoopOpt.) - const bool do_loop_opt = !get_target().has_feature(Target::DisableLLVMLoopOpt) || - get_target().has_feature(Target::EnableLLVMLoopOpt); + // halide_target_feature_disable_llvm_loop_opt is deprecated in Halide 15 + // (and will be removed in Halide 16). Halide 15 now defaults to disabling + // LLVM loop optimization, unless halide_target_feature_enable_llvm_loop_opt is set. + if (get_target().has_feature(Target::DisableLLVMLoopOpt)) { + user_warning << "halide_target_feature_disable_llvm_loop_opt is deprecated in Halide 15 " + "(and will be removed in Halide 16). Halide 15 now defaults to disabling " + "LLVM loop optimization, unless halide_target_feature_enable_llvm_loop_opt is set.\n"; + } + const bool do_loop_opt = get_target().has_feature(Target::EnableLLVMLoopOpt); PipelineTuningOptions pto; pto.LoopInterleaving = do_loop_opt; @@ -1126,28 +1119,21 @@ void CodeGen_LLVM::optimize_module() { // 21.04 -> 14.78 using current ToT release build. (See also https://reviews.llvm.org/rL358304) pto.ForgetAllSCEVInLoopUnroll = true; -#if LLVM_VERSION >= 130 llvm::PassBuilder pb(tm.get(), pto); -#else - llvm::PassBuilder pb(/*DebugLogging*/ false, tm.get(), pto); -#endif bool debug_pass_manager = false; // These analysis managers have to be declared in this order. -#if LLVM_VERSION >= 130 llvm::LoopAnalysisManager lam; llvm::FunctionAnalysisManager fam; llvm::CGSCCAnalysisManager cgam; llvm::ModuleAnalysisManager mam; -#else - llvm::LoopAnalysisManager lam(debug_pass_manager); - llvm::FunctionAnalysisManager fam(debug_pass_manager); - llvm::CGSCCAnalysisManager cgam(debug_pass_manager); - llvm::ModuleAnalysisManager mam(debug_pass_manager); -#endif +#if LLVM_VERSION < 140 + // If building against LLVM older than 14, explicitly specify AA pipeline. + // Not needed with LLVM14 or later, already the default. llvm::AAManager aa = pb.buildDefaultAAPipeline(); fam.registerPass([&] { return std::move(aa); }); +#endif // Register all the basic analyses with the managers. pb.registerModuleAnalyses(mam); @@ -1155,11 +1141,7 @@ void CodeGen_LLVM::optimize_module() { pb.registerFunctionAnalyses(fam); pb.registerLoopAnalyses(lam); pb.crossRegisterProxies(lam, fam, cgam, mam); -#if LLVM_VERSION >= 130 ModulePassManager mpm; -#else - ModulePassManager mpm(debug_pass_manager); -#endif #if LLVM_VERSION >= 140 using OptimizationLevel = llvm::OptimizationLevel; @@ -1189,9 +1171,13 @@ void CodeGen_LLVM::optimize_module() { } if (get_target().has_feature(Target::ASAN)) { +#if LLVM_VERSION >= 150 + // Nothing, ASanGlobalsMetadataAnalysis no longer exists +#else pb.registerPipelineStartEPCallback([&](ModulePassManager &mpm, OptimizationLevel) { mpm.addPass(RequireAnalysisPass()); }); +#endif pb.registerPipelineStartEPCallback([](ModulePassManager &mpm, OptimizationLevel) { #if LLVM_VERSION >= 140 AddressSanitizerOptions asan_options; // default values are good... @@ -1246,15 +1232,9 @@ void CodeGen_LLVM::optimize_module() { } } -#if LLVM_VERSION >= 130 if (tm) { tm->registerPassBuilderCallbacks(pb); } -#else - if (tm) { - tm->registerPassBuilderCallbacks(pb, debug_pass_manager); - } -#endif mpm = pb.buildPerModuleDefaultPipeline(level, debug_pass_manager); mpm.run(*module, mam); @@ -1328,6 +1308,9 @@ Value *CodeGen_LLVM::codegen(const Expr &e) { value = builder->CreateExtractElement(value, ConstantInt::get(i32_t, 0)); } + // Make sure fixed/vscale property of vector types match what is exepected. + value = normalize_fixed_scalable_vector_type(llvm_type_of(e.type()), value); + // TODO: skip this correctness check for bool vectors, // as eliminate_bool_vectors() will cause a discrepancy for some backends // (eg OpenCL, HVX, WASM); for now we're just ignoring the assert, but @@ -1493,6 +1476,18 @@ void CodeGen_LLVM::visit(const Cast *op) { } } +void CodeGen_LLVM::visit(const Reinterpret *op) { + Type dst = op->type; + llvm::Type *llvm_dst = llvm_type_of(dst); + value = codegen(op->value); + // Our `Reinterpret` expr directly maps to LLVM IR bitcast/ptrtoint/inttoptr + // instructions with no additional handling required: + // * bitcast between vectors and scalars is well-formed. + // * ptrtoint/inttoptr implicitly truncates/zero-extends the integer + // to match the pointer size. + value = builder->CreateBitOrPointerCast(value, llvm_dst); +} + void CodeGen_LLVM::visit(const Variable *op) { value = sym_get(op->name); } @@ -2026,15 +2021,11 @@ void CodeGen_LLVM::visit(const Load *op) { Value *load_i = codegen_dense_vector_load(op->type.with_lanes(load_lanes_i), op->name, slice_base, op->image, op->param, align, nullptr, false); - SmallVector constants; + std::vector constants; for (int j = 0; j < lanes_i; j++) { - Constant *constant = ConstantInt::get(i32_t, j * stride->value + offset); - constants.push_back(constant); + constants.push_back(j * stride->value + offset); } - Constant *constantsV = ConstantVector::get(constants); - Value *undef = UndefValue::get(load_i->getType()); - Value *shuffleInstr = builder->CreateShuffleVector(load_i, undef, constantsV); - results.push_back(shuffleInstr); + results.push_back(shuffle_vectors(load_i, constants)); } // Concat the results @@ -2061,7 +2052,7 @@ void CodeGen_LLVM::visit(const Load *op) { // Gather without generating the indices as a vector Value *ptr = codegen_buffer_pointer(op->name, op->type.element_of(), ramp->base); Value *stride = codegen(ramp->stride); - value = UndefValue::get(llvm_type_of(op->type)); + value = PoisonValue::get(llvm_type_of(op->type)); for (int i = 0; i < ramp->lanes; i++) { Value *lane = ConstantInt::get(i32_t, i); LoadInst *val = builder->CreateLoad(load_type, ptr); @@ -2075,7 +2066,7 @@ void CodeGen_LLVM::visit(const Load *op) { // loads in it, and it's all int32. // Compute the index as scalars, and then do a gather - Value *vec = UndefValue::get(llvm_type_of(op->type)); + Value *vec = PoisonValue::get(llvm_type_of(op->type)); for (int i = 0; i < op->type.lanes(); i++) { Expr idx = extract_lane(op->index, i); Value *ptr = codegen_buffer_pointer(op->name, op->type.element_of(), idx); @@ -2087,7 +2078,7 @@ void CodeGen_LLVM::visit(const Load *op) { } else { // General gathers Value *index = codegen(op->index); - Value *vec = UndefValue::get(llvm_type_of(op->type)); + Value *vec = PoisonValue::get(llvm_type_of(op->type)); for (int i = 0; i < op->type.lanes(); i++) { Value *idx = builder->CreateExtractElement(index, ConstantInt::get(i32_t, i)); Value *ptr = codegen_buffer_pointer(op->name, op->type.element_of(), idx); @@ -2122,7 +2113,7 @@ void CodeGen_LLVM::visit(const Ramp *op) { Value *base = codegen(op->base); Value *stride = codegen(op->stride); - value = UndefValue::get(llvm_type_of(op->type)); + value = PoisonValue::get(llvm_type_of(op->type)); for (int i = 0; i < op->type.lanes(); i++) { if (i > 0) { if (op->type.is_float()) { @@ -2139,11 +2130,11 @@ void CodeGen_LLVM::visit(const Ramp *op) { } llvm::Value *CodeGen_LLVM::create_broadcast(llvm::Value *v, int lanes) { - Constant *undef = UndefValue::get(get_vector_type(v->getType(), lanes)); + Constant *poison = PoisonValue::get(get_vector_type(v->getType(), lanes)); Constant *zero = ConstantInt::get(i32_t, 0); - v = builder->CreateInsertElement(undef, v, zero); + v = builder->CreateInsertElement(poison, v, zero); Constant *zeros = ConstantVector::getSplat(element_count(lanes), zero); - return builder->CreateShuffleVector(v, undef, zeros); + return builder->CreateShuffleVector(v, poison, zeros); } void CodeGen_LLVM::visit(const Broadcast *op) { @@ -2225,7 +2216,7 @@ Value *CodeGen_LLVM::interleave_vectors(const std::vector &vecs) { void CodeGen_LLVM::scalarize(const Expr &e) { llvm::Type *result_type = llvm_type_of(e.type()); - Value *result = UndefValue::get(result_type); + Value *result = PoisonValue::get(result_type); for (int i = 0; i < e.type().lanes(); i++) { Value *v = codegen(extract_lane(e, i)); @@ -2376,11 +2367,7 @@ llvm::Value *CodeGen_LLVM::codegen_dense_vector_load(const Type &type, const std Instruction *load_inst; if (vpred != nullptr) { Value *slice_mask = slice_vector(vpred, i, slice_lanes); -#if LLVM_VERSION >= 130 load_inst = builder->CreateMaskedLoad(slice_type, vec_ptr, llvm::Align(align_bytes), slice_mask); -#else - load_inst = builder->CreateMaskedLoad(vec_ptr, llvm::Align(align_bytes), slice_mask); -#endif } else { load_inst = builder->CreateAlignedLoad(slice_type, vec_ptr, llvm::Align(align_bytes)); } @@ -2467,20 +2454,11 @@ void CodeGen_LLVM::codegen_atomic_rmw(const Store *op) { Value *ptr = codegen_buffer_pointer(op->name, op->value.type(), op->index); -#if LLVM_VERSION >= 130 if (value_type.is_float()) { builder->CreateAtomicRMW(AtomicRMWInst::FAdd, ptr, val, llvm::MaybeAlign(), AtomicOrdering::Monotonic); } else { builder->CreateAtomicRMW(AtomicRMWInst::Add, ptr, val, llvm::MaybeAlign(), AtomicOrdering::Monotonic); } -#else - // llvm 9 has FAdd which can be used for atomic floats. - if (value_type.is_float()) { - builder->CreateAtomicRMW(AtomicRMWInst::FAdd, ptr, val, AtomicOrdering::Monotonic); - } else { - builder->CreateAtomicRMW(AtomicRMWInst::Add, ptr, val, AtomicOrdering::Monotonic); - } -#endif } else { Value *index = codegen(op->index); // Scalarize vector store. @@ -2489,19 +2467,11 @@ void CodeGen_LLVM::codegen_atomic_rmw(const Store *op) { Value *idx = builder->CreateExtractElement(index, lane); Value *v = builder->CreateExtractElement(val, lane); Value *ptr = codegen_buffer_pointer(op->name, value_type.element_of(), idx); -#if LLVM_VERSION >= 130 if (value_type.is_float()) { builder->CreateAtomicRMW(AtomicRMWInst::FAdd, ptr, v, llvm::MaybeAlign(), AtomicOrdering::Monotonic); } else { builder->CreateAtomicRMW(AtomicRMWInst::Add, ptr, v, llvm::MaybeAlign(), AtomicOrdering::Monotonic); } -#else - if (value_type.is_float()) { - builder->CreateAtomicRMW(AtomicRMWInst::FAdd, ptr, v, AtomicOrdering::Monotonic); - } else { - builder->CreateAtomicRMW(AtomicRMWInst::Add, ptr, v, AtomicOrdering::Monotonic); - } -#endif } } } else { @@ -2564,13 +2534,8 @@ void CodeGen_LLVM::codegen_atomic_rmw(const Store *op) { val = builder->CreateBitCast(val, int_type); cmp_val = builder->CreateBitCast(cmp_val, int_type); } -#if LLVM_VERSION >= 130 Value *cmpxchg_pair = builder->CreateAtomicCmpXchg( ptr, cmp_val, val, llvm::MaybeAlign(), AtomicOrdering::Monotonic, AtomicOrdering::Monotonic); -#else - Value *cmpxchg_pair = builder->CreateAtomicCmpXchg( - ptr, cmp_val, val, AtomicOrdering::Monotonic, AtomicOrdering::Monotonic); -#endif Value *val_loaded = builder->CreateExtractValue(cmpxchg_pair, 0, "val_loaded"); Value *success = builder->CreateExtractValue(cmpxchg_pair, 1, "success"); if (need_bit_cast) { @@ -2636,61 +2601,6 @@ void CodeGen_LLVM::visit(const Call *op) { internal_assert(op->args.size() == 1); Value *a = codegen(op->args[0]); value = builder->CreateNot(a); - } else if (op->is_intrinsic(Call::reinterpret)) { - internal_assert(op->args.size() == 1); - Type dst = op->type; - Type src = op->args[0].type(); - llvm::Type *llvm_dst = llvm_type_of(dst); - value = codegen(op->args[0]); - if (src.is_handle() && !dst.is_handle()) { - internal_assert(dst.is_uint() && dst.bits() == 64); - - // Handle -> UInt64 - llvm::DataLayout d(module.get()); - if (d.getPointerSize() == 4) { - llvm::Type *intermediate = llvm_type_of(UInt(32, dst.lanes())); - value = builder->CreatePtrToInt(value, intermediate); - value = builder->CreateZExt(value, llvm_dst); - } else if (d.getPointerSize() == 8) { - value = builder->CreatePtrToInt(value, llvm_dst); - } else { - internal_error << "Pointer size is neither 4 nor 8 bytes\n"; - } - - } else if (dst.is_handle() && !src.is_handle()) { - internal_assert(src.is_uint() && src.bits() == 64); - - // UInt64 -> Handle - llvm::DataLayout d(module.get()); - if (d.getPointerSize() == 4) { - llvm::Type *intermediate = llvm_type_of(UInt(32, src.lanes())); - value = builder->CreateTrunc(value, intermediate); - value = builder->CreateIntToPtr(value, llvm_dst); - } else if (d.getPointerSize() == 8) { - value = builder->CreateIntToPtr(value, llvm_dst); - } else { - internal_error << "Pointer size is neither 4 nor 8 bytes\n"; - } - - } else { - if (src.is_scalar() && dst.is_vector()) { - // If the source type is a scalar, we promote it to an - // equivalent vector of width one before doing the - // bitcast, because llvm's bitcast operator doesn't - // want to convert between scalars and vectors. - value = create_broadcast(value, 1); - } - if (src.is_vector() && dst.is_scalar()) { - // Similarly, if we're converting from a vector to a - // scalar, convert to a vector of width 1 first, and - // then extract the first lane. - llvm_dst = get_vector_type(llvm_dst, 1); - } - value = builder->CreateBitCast(value, llvm_dst); - if (src.is_vector() && dst.is_scalar()) { - value = builder->CreateExtractElement(value, (uint64_t)0); - } - } } else if (op->is_intrinsic(Call::shift_left)) { internal_assert(op->args.size() == 2); if (op->args[1].type().is_uint()) { @@ -2787,8 +2697,8 @@ void CodeGen_LLVM::visit(const Call *op) { llvm::Function *fn = llvm::Intrinsic::getDeclaration(module.get(), (op->is_intrinsic(Call::count_leading_zeros)) ? llvm::Intrinsic::ctlz : llvm::Intrinsic::cttz, arg_type); - llvm::Value *is_const_zero_undef = llvm::ConstantInt::getFalse(*context); - llvm::Value *args[2] = {codegen(op->args[0]), is_const_zero_undef}; + llvm::Value *is_const_zero_poison = llvm::ConstantInt::getFalse(*context); + llvm::Value *args[2] = {codegen(op->args[0]), is_const_zero_poison}; CallInst *call = builder->CreateCall(fn, args); value = call; } else if (op->is_intrinsic(Call::return_second)) { @@ -2922,7 +2832,7 @@ void CodeGen_LLVM::visit(const Call *op) { internal_assert(op->args.size() == 2); // Try to fold the vector reduce for a call to saturating_add - const bool folded = try_to_fold_vector_reduce(op->args[0], op->args[1]); + const bool folded = op->is_intrinsic(Call::saturating_add) && try_to_fold_vector_reduce(op->args[0], op->args[1]); if (!folded) { std::string intrin; @@ -3269,7 +3179,7 @@ void CodeGen_LLVM::visit(const Call *op) { " integer overflow for int32 and int64 is undefined behavior in" " Halide.\n"; } else if (op->is_intrinsic(Call::undef)) { - value = UndefValue::get(llvm_type_of(op->type)); + user_error << "undef not eliminated before code generation. Please report this as a Halide bug.\n"; } else if (op->is_intrinsic(Call::size_of_halide_buffer_t)) { llvm::DataLayout d(module.get()); value = ConstantInt::get(i32_t, (int)d.getTypeAllocSize(halide_buffer_t_type)); @@ -3284,6 +3194,10 @@ void CodeGen_LLVM::visit(const Call *op) { value = codegen(lower_float16_transcendental_to_float32_equivalent(op)); } else if (op->is_intrinsic(Call::mux)) { value = codegen(lower_mux(op)); + } else if (op->is_intrinsic(Call::extract_bits)) { + value = codegen(lower_extract_bits(op)); + } else if (op->is_intrinsic(Call::concat_bits)) { + value = codegen(lower_concat_bits(op)); } else if (op->is_intrinsic()) { Expr lowered = lower_intrinsic(op); if (!lowered.defined()) { @@ -3467,7 +3381,6 @@ void CodeGen_LLVM::visit(const Call *op) { if (op->is_pure()) { call->setDoesNotAccessMemory(); } - call->setDoesNotThrow(); value = call; } else { @@ -3485,7 +3398,7 @@ void CodeGen_LLVM::visit(const Call *op) { // No vector version found. Scalarize. Extract each simd // lane in turn and do one scalar call to the function. - value = UndefValue::get(result_type); + value = PoisonValue::get(result_type); for (int i = 0; i < op->type.lanes(); i++) { Value *idx = ConstantInt::get(i32_t, i); vector arg_lane(args.size()); @@ -3500,7 +3413,6 @@ void CodeGen_LLVM::visit(const Call *op) { if (op->is_pure()) { call->setDoesNotAccessMemory(); } - call->setDoesNotThrow(); if (!call->getType()->isVoidTy()) { value = builder->CreateInsertElement(value, call, idx); } // otherwise leave it as undef. @@ -4560,8 +4472,9 @@ Value *CodeGen_LLVM::call_intrin(const Type &result_type, int intrin_lanes, intrin, arg_values); } -Value *CodeGen_LLVM::call_intrin(llvm::Type *result_type, int intrin_lanes, - const string &name, vector arg_values) { +Value *CodeGen_LLVM::call_intrin(const llvm::Type *result_type, int intrin_lanes, + const string &name, vector arg_values, + bool scalable_vector_result) { llvm::Function *fn = module->getFunction(name); if (!fn) { vector arg_types(arg_values.size()); @@ -4571,7 +4484,13 @@ Value *CodeGen_LLVM::call_intrin(llvm::Type *result_type, int intrin_lanes, llvm::Type *intrinsic_result_type = result_type->getScalarType(); if (intrin_lanes > 1) { - intrinsic_result_type = get_vector_type(result_type->getScalarType(), intrin_lanes); + if (scalable_vector_result && effective_vscale != 0) { + intrinsic_result_type = get_vector_type(result_type->getScalarType(), + intrin_lanes / effective_vscale, true); + } else { + intrinsic_result_type = get_vector_type(result_type->getScalarType(), + intrin_lanes); + } } FunctionType *func_t = FunctionType::get(intrinsic_result_type, arg_types, false); fn = llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, name, module.get()); @@ -4581,7 +4500,7 @@ Value *CodeGen_LLVM::call_intrin(llvm::Type *result_type, int intrin_lanes, return call_intrin(result_type, intrin_lanes, fn, arg_values); } -Value *CodeGen_LLVM::call_intrin(llvm::Type *result_type, int intrin_lanes, +Value *CodeGen_LLVM::call_intrin(const llvm::Type *result_type, int intrin_lanes, llvm::Function *intrin, vector arg_values) { internal_assert(intrin); int arg_lanes = 1; @@ -4635,6 +4554,9 @@ Value *CodeGen_LLVM::call_intrin(llvm::Type *result_type, int intrin_lanes, llvm::FunctionType *intrin_type = intrin->getFunctionType(); for (int i = 0; i < (int)arg_values.size(); i++) { + if (arg_values[i]->getType() != intrin_type->getParamType(i)) { + arg_values[i] = normalize_fixed_scalable_vector_type(intrin_type->getParamType(i), arg_values[i]); + } if (arg_values[i]->getType() != intrin_type->getParamType(i)) { // There can be some mismatches in types, such as when passing scalar Halide type T // to LLVM vector type <1 x T>. @@ -4750,14 +4672,20 @@ Value *CodeGen_LLVM::shuffle_vectors(Value *a, Value *b, } else { // Only let -1 be undef. internal_assert(indices[i] == -1); - llvm_indices[i] = UndefValue::get(i32_t); + llvm_indices[i] = PoisonValue::get(i32_t); } } + if (isa(a->getType())) { + a = scalable_to_fixed_vector_type(a); + } + if (isa(b->getType())) { + b = scalable_to_fixed_vector_type(b); + } return builder->CreateShuffleVector(a, b, ConstantVector::get(llvm_indices)); } Value *CodeGen_LLVM::shuffle_vectors(Value *a, const std::vector &indices) { - Value *b = UndefValue::get(a->getType()); + Value *b = PoisonValue::get(a->getType()); return shuffle_vectors(a, b, indices); } @@ -4810,5 +4738,106 @@ bool CodeGen_LLVM::supports_call_as_float16(const Call *op) const { return false; } +llvm::Value *CodeGen_LLVM::normalize_fixed_scalable_vector_type(llvm::Type *desired_type, llvm::Value *result) { + llvm::Type *actual_type = result->getType(); + + if (isa(actual_type) && + isa(desired_type)) { + const llvm::FixedVectorType *fixed = cast(actual_type); + const llvm::ScalableVectorType *scalable = cast(desired_type); + if (fixed->getElementType() == scalable->getElementType()) { + return fixed_to_scalable_vector_type(result); + } + } else if (isa(desired_type) && + isa(actual_type)) { + const llvm::ScalableVectorType *scalable = cast(actual_type); + const llvm::FixedVectorType *fixed = cast(desired_type); + if (fixed->getElementType() == scalable->getElementType()) { + return scalable_to_fixed_vector_type(result); + } + } + + return result; +} + +llvm::Value *CodeGen_LLVM::fixed_to_scalable_vector_type(llvm::Value *fixed_arg) { + internal_assert(effective_vscale != 0); + internal_assert(isa(fixed_arg->getType())); + const llvm::FixedVectorType *fixed = cast(fixed_arg->getType()); + internal_assert(fixed != nullptr); + auto lanes = fixed->getNumElements(); + + const llvm::ScalableVectorType *scalable = cast(get_vector_type(fixed->getElementType(), + lanes / effective_vscale, true)); + internal_assert(fixed != nullptr); + + internal_assert(fixed->getElementType() == scalable->getElementType()); + internal_assert(lanes == (scalable->getMinNumElements() * effective_vscale)); + + // E.g. llvm.experimental.vector.insert.nxv2i64.v4i64(, <4 x i64>, i64) + const char *type_designator; + if (fixed->getElementType()->isIntegerTy()) { + type_designator = "i"; + } else { + type_designator = "f"; + } + std::string intrin = "llvm.experimental.vector.insert.nxv" + std::to_string(scalable->getMinNumElements()); + intrin += type_designator; + std::string bits_designator = std::to_string(fixed->getScalarSizeInBits()); + intrin += bits_designator; + intrin += ".v" + std::to_string(lanes) + type_designator + bits_designator; + Constant *poison = PoisonValue::get(scalable->getElementType()); + llvm::Value *result_vec = ConstantVector::getSplat(scalable->getElementCount(), poison); + + std::vector args; + args.push_back(result_vec); + args.push_back(value); + args.push_back(ConstantInt::get(i64_t, 0)); + return call_intrin(scalable, lanes, intrin, args, true); +} + +llvm::Value *CodeGen_LLVM::scalable_to_fixed_vector_type(llvm::Value *scalable_arg) { + internal_assert(effective_vscale != 0); + internal_assert(isa(scalable_arg->getType())); + const llvm::ScalableVectorType *scalable = cast(scalable_arg->getType()); + internal_assert(scalable != nullptr); + + const llvm::FixedVectorType *fixed = cast(get_vector_type(scalable->getElementType(), + scalable->getMinNumElements() * effective_vscale, false)); + internal_assert(fixed != nullptr); + + internal_assert(fixed->getElementType() == scalable->getElementType()); + internal_assert(fixed->getNumElements() == (scalable->getMinNumElements() * effective_vscale)); + + // E.g. <64 x i8> @llvm.experimental.vector.extract.v64i8.nxv8i8( %vresult, i64 0) + const char *type_designator; + if (scalable->getElementType()->isIntegerTy()) { + type_designator = "i"; + } else { + type_designator = "f"; + } + std::string bits_designator = std::to_string(fixed->getScalarSizeInBits()); + std::string intrin = "llvm.experimental.vector.extract.v" + std::to_string(fixed->getNumElements()) + type_designator + bits_designator; + intrin += ".nxv" + std::to_string(scalable->getMinNumElements()) + type_designator + bits_designator; + std::vector args; + args.push_back(scalable_arg); + args.push_back(ConstantInt::get(i64_t, 0)); + + return call_intrin(fixed, fixed->getNumElements(), intrin, args, false); +} + +int CodeGen_LLVM::get_vector_num_elements(const llvm::Type *t) { + if (isa(t)) { + const auto *vt = cast(t); + return vt->getNumElements(); + } else if (isa(t)) { + internal_assert(effective_vscale != 0) << "Scalable vector type enountered without vector_bits being set.\n"; + const auto *vt = cast(t); + return vt->getMinNumElements() * effective_vscale; + } else { + return 1; + } +} + } // namespace Internal } // namespace Halide diff --git a/src/CodeGen_LLVM.h b/src/CodeGen_LLVM.h index 2811c1be9c31..b44a7bafd3be 100644 --- a/src/CodeGen_LLVM.h +++ b/src/CodeGen_LLVM.h @@ -106,11 +106,21 @@ class CodeGen_LLVM : public IRVisitor { virtual void end_func(const std::vector &args); // @} - /** What should be passed as -mcpu, -mattrs, and related for - * compilation. The architecture-specific code generator should - * define these. */ + /** What should be passed as -mcpu (warning: implies attrs!), -mattrs, + * and related for compilation. The architecture-specific code generator + * should define these. + * + * `mcpu_target()` - target this specific CPU, in the sense of the allowed + * ISA sets *and* the CPU-specific tuning/assembly instruction scheduling. + * + * `mcpu_tune()` - expect that we will be running on this specific CPU, + * so perform CPU-specific tuning/assembly instruction scheduling, *but* + * DON'T sacrifice the portability, support running on other CPUs, only + * make use of the ISAs that are enabled by `mcpu_target()`+`mattrs()`. + */ // @{ - virtual std::string mcpu() const = 0; + virtual std::string mcpu_target() const = 0; + virtual std::string mcpu_tune() const = 0; virtual std::string mattrs() const = 0; virtual std::string mabi() const; virtual bool use_soft_float_abi() const = 0; @@ -126,6 +136,14 @@ class CodeGen_LLVM : public IRVisitor { /** What's the natural vector bit-width to use for loads, stores, etc. */ virtual int native_vector_bits() const = 0; + /** For architectures that have vscale vectors, return the constant vscale to use. + * Default of 0 means do not use vscale vectors. Generally will depend on + * the target flags and vector_bits settings. + */ + virtual int target_vscale() const { + return 0; + } + /** Return the type in which arithmetic should be done for the * given storage type. */ virtual Type upgrade_type_for_arithmetic(const Type &) const; @@ -159,8 +177,10 @@ class CodeGen_LLVM : public IRVisitor { * multiple related modules (e.g. multiple device kernels). */ virtual void init_module(); +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE /** Add external_code entries to llvm module. */ void add_external_code(const Module &halide_module); +#endif /** Run all of llvm's optimization passes on the module. */ void optimize_module(); @@ -197,9 +217,7 @@ class CodeGen_LLVM : public IRVisitor { *scalar_value_t_type, *device_interface_t_type, *pseudostack_slot_t_type, - *semaphore_t_type, - *semaphore_acquire_t_type, - *parallel_task_t_type; + *semaphore_t_type; // @} @@ -311,6 +329,7 @@ class CodeGen_LLVM : public IRVisitor { void visit(const FloatImm *) override; void visit(const StringImm *) override; void visit(const Cast *) override; + void visit(const Reinterpret *) override; void visit(const Variable *) override; void visit(const Add *) override; void visit(const Sub *) override; @@ -440,9 +459,10 @@ class CodeGen_LLVM : public IRVisitor { const std::string &name, std::vector); llvm::Value *call_intrin(const Type &t, int intrin_lanes, llvm::Function *intrin, std::vector); - llvm::Value *call_intrin(llvm::Type *t, int intrin_lanes, - const std::string &name, std::vector); - llvm::Value *call_intrin(llvm::Type *t, int intrin_lanes, + llvm::Value *call_intrin(const llvm::Type *t, int intrin_lanes, + const std::string &name, std::vector, + bool scalable_vector_result = false); + llvm::Value *call_intrin(const llvm::Type *t, int intrin_lanes, llvm::Function *intrin, std::vector); // @} @@ -456,7 +476,7 @@ class CodeGen_LLVM : public IRVisitor { /** Create an LLVM shuffle vectors instruction. */ virtual llvm::Value *shuffle_vectors(llvm::Value *a, llvm::Value *b, const std::vector &indices); - /** Shorthand for shuffling a vector with an undef vector. */ + /** Shorthand for shuffling a single vector. */ llvm::Value *shuffle_vectors(llvm::Value *v, const std::vector &indices); /** Go looking for a vector version of a runtime function. Will @@ -497,6 +517,19 @@ class CodeGen_LLVM : public IRVisitor { This is used to avoid "emulated" equivalent code-gen in case target has FP16 feature **/ virtual bool supports_call_as_float16(const Call *op) const; + /** Ensure that a vector value is either fixed or vscale depending to match desired_type. + */ + llvm::Value *normalize_fixed_scalable_vector_type(llvm::Type *desired_type, llvm::Value *result); + + /** Convert an LLVM fixed vector value to the corresponding vscale vector value. */ + llvm::Value *fixed_to_scalable_vector_type(llvm::Value *fixed); + + /** Convert an LLVM vscale vector value to the corresponding fixed vector value. */ + llvm::Value *scalable_to_fixed_vector_type(llvm::Value *scalable); + + /** Get number of vector elements, taking into account scalable vectors. Returns 1 for scalars. */ + int get_vector_num_elements(const llvm::Type *t); + private: /** All the values in scope at the current code location during * codegen. Use sym_push and sym_pop to access. */ @@ -517,6 +550,11 @@ class CodeGen_LLVM : public IRVisitor { /** Use the LLVM large code model when this is set. */ bool llvm_large_code_model; + /** Cache the result of target_vscale from architecture specific implementation + * as this is used on every Halide to LLVM type conversion. + */ + int effective_vscale; + /** Embed an instance of halide_filter_metadata_t in the code, using * the given name (by convention, this should be ${FUNCTIONNAME}_metadata) * as extern "C" linkage. Note that the return value is a function-returning- @@ -524,7 +562,7 @@ class CodeGen_LLVM : public IRVisitor { */ llvm::Function *embed_metadata_getter(const std::string &metadata_getter_name, const std::string &function_name, const std::vector &args, - const std::map &metadata_name_map); + const MetadataNameMap &metadata_name_map); /** Embed a constant expression as a global variable. */ llvm::Constant *embed_constant_expr(Expr e, llvm::Type *t); diff --git a/src/CodeGen_MIPS.cpp b/src/CodeGen_MIPS.cpp index 4118a12b684f..26bd3a502146 100644 --- a/src/CodeGen_MIPS.cpp +++ b/src/CodeGen_MIPS.cpp @@ -19,7 +19,8 @@ class CodeGen_MIPS : public CodeGen_Posix { protected: using CodeGen_Posix::visit; - string mcpu() const override; + string mcpu_target() const override; + string mcpu_tune() const override; string mattrs() const override; bool use_soft_float_abi() const override; int native_vector_bits() const override; @@ -29,7 +30,7 @@ CodeGen_MIPS::CodeGen_MIPS(const Target &t) : CodeGen_Posix(t) { } -string CodeGen_MIPS::mcpu() const { +string CodeGen_MIPS::mcpu_target() const { if (target.bits == 32) { return ""; } else { @@ -37,6 +38,10 @@ string CodeGen_MIPS::mcpu() const { } } +string CodeGen_MIPS::mcpu_tune() const { + return mcpu_target(); +} + string CodeGen_MIPS::mattrs() const { if (target.bits == 32) { return ""; diff --git a/src/CodeGen_Metal_Dev.cpp b/src/CodeGen_Metal_Dev.cpp index 6089d84b9d3a..c63f23a1b79f 100644 --- a/src/CodeGen_Metal_Dev.cpp +++ b/src/CodeGen_Metal_Dev.cpp @@ -2,7 +2,6 @@ #include #include -#include "CodeGen_C.h" #include "CodeGen_GPU_Dev.h" #include "CodeGen_Internal.h" #include "CodeGen_Metal_Dev.h" @@ -50,17 +49,17 @@ class CodeGen_Metal_Dev : public CodeGen_GPU_Dev { } protected: - class CodeGen_Metal_C : public CodeGen_C { + class CodeGen_Metal_C : public CodeGen_GPU_C { public: CodeGen_Metal_C(std::ostream &s, const Target &t) - : CodeGen_C(s, t) { + : CodeGen_GPU_C(s, t) { } void add_kernel(const Stmt &stmt, const std::string &name, const std::vector &args); protected: - using CodeGen_C::visit; + using CodeGen_GPU_C::visit; std::string print_type(Type type, AppendSpaceIfNeeded space_option = DoNotAppendSpace) override; // Vectors in Metal come in two varieties, regular and packed. // For storage allocations and pointers used in address arithmetic, @@ -267,7 +266,7 @@ void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const For *loop) { } else { user_assert(loop->for_type != ForType::Parallel) << "Cannot use parallel loops inside Metal kernel\n"; - CodeGen_C::visit(loop); + CodeGen_GPU_C::visit(loop); } } @@ -321,7 +320,7 @@ void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const Call *op) { stream << ");\n"; print_assignment(op->type, "0"); } else { - CodeGen_C::visit(op); + CodeGen_GPU_C::visit(op); } } @@ -789,7 +788,7 @@ void CodeGen_Metal_Dev::init_module() { << "#endif\n" << "}\n"; // close namespace - src_stream << "#define halide_unused(x) (void)(x)\n"; + src_stream << "#define halide_maybe_unused(x) (void)(x)\n"; src_stream << "\n"; diff --git a/src/CodeGen_OpenCL_Dev.cpp b/src/CodeGen_OpenCL_Dev.cpp index 539a0699b909..3302b9413076 100644 --- a/src/CodeGen_OpenCL_Dev.cpp +++ b/src/CodeGen_OpenCL_Dev.cpp @@ -4,7 +4,6 @@ #include #include "CSE.h" -#include "CodeGen_C.h" #include "CodeGen_GPU_Dev.h" #include "CodeGen_Internal.h" #include "CodeGen_OpenCL_Dev.h" @@ -55,18 +54,19 @@ class CodeGen_OpenCL_Dev : public CodeGen_GPU_Dev { } protected: - class CodeGen_OpenCL_C : public CodeGen_C { + class CodeGen_OpenCL_C : public CodeGen_GPU_C { public: CodeGen_OpenCL_C(std::ostream &s, Target t) - : CodeGen_C(s, t) { + : CodeGen_GPU_C(s, t) { integer_suffix_style = IntegerSuffixStyle::OpenCL; + vector_declaration_style = VectorDeclarationStyle::OpenCLSyntax; } void add_kernel(Stmt stmt, const std::string &name, const std::vector &args); protected: - using CodeGen_C::visit; + using CodeGen_GPU_C::visit; std::string print_type(Type type, AppendSpaceIfNeeded append_space = DoNotAppendSpace) override; std::string print_reinterpret(Type type, const Expr &e) override; std::string print_extern_call(const Call *op) override; @@ -223,7 +223,7 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const For *loop) { } else { user_assert(loop->for_type != ForType::Parallel) << "Cannot use parallel loops inside OpenCL kernel\n"; - CodeGen_C::visit(loop); + CodeGen_GPU_C::visit(loop); } } @@ -351,7 +351,7 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const Call *op) { print_assignment(op->type, a0 + " >> " + a1); } } else { - CodeGen_C::visit(op); + CodeGen_GPU_C::visit(op); } } else if (op->is_intrinsic(Call::image_load)) { // image_load(, , , , , @@ -455,7 +455,7 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const Call *op) { stream << write_image.str(); } } else { - CodeGen_C::visit(op); + CodeGen_GPU_C::visit(op); } } @@ -743,7 +743,7 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const Cast *op) { if (op->type.is_vector()) { print_assignment(op->type, "convert_" + print_type(op->type) + "(" + print_expr(op->value) + ")"); } else { - CodeGen_C::visit(op); + CodeGen_GPU_C::visit(op); } } @@ -755,7 +755,7 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const Select *op) { equiv.accept(this); return; } - CodeGen_C::visit(op); + CodeGen_GPU_C::visit(op); } void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const Allocate *op) { @@ -858,8 +858,14 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const Shuffle *op) { } stream << ");\n"; } + } else if (op->is_extract_element()) { + // OpenCL requires using .s format for extracting an element + ostringstream rhs; + rhs << print_expr(op->vectors[0]); + rhs << ".s" << op->indices[0]; + print_assignment(op->type, rhs.str()); } else { - internal_error << "Shuffle not implemented.\n"; + CodeGen_GPU_C::visit(op); } } @@ -879,7 +885,7 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const Atomic *op) { // Issue atomic stores. ScopedValue old_emit_atomic_stores(emit_atomic_stores, true); - CodeGen_C::visit(op); + CodeGen_GPU_C::visit(op); } void CodeGen_OpenCL_Dev::add_kernel(Stmt s, @@ -926,6 +932,13 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::add_kernel(Stmt s, debug(2) << "After eliminating bool vectors:\n" << s << "\n"; + // We need to scalarize/de-predicate any loads/stores, since OpenCL does not + // support predication. + s = scalarize_predicated_loads_stores(s); + + debug(2) << "After removing predication: \n" + << s; + // Figure out which arguments should be passed in __constant. // Such arguments should be: // - not written to, @@ -1136,7 +1149,7 @@ void CodeGen_OpenCL_Dev::init_module() { // There does not appear to be a reliable way to safely ignore unused // variables in OpenCL C. See https://github.com/halide/Halide/issues/4918. - src_stream << "#define halide_unused(x)\n"; + src_stream << "#define halide_maybe_unused(x)\n"; if (target.has_feature(Target::CLDoubles)) { src_stream << "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n" diff --git a/src/CodeGen_OpenGLCompute_Dev.cpp b/src/CodeGen_OpenGLCompute_Dev.cpp index 9eec624c7e6c..3dc4ea6604fa 100644 --- a/src/CodeGen_OpenGLCompute_Dev.cpp +++ b/src/CodeGen_OpenGLCompute_Dev.cpp @@ -870,7 +870,7 @@ void CodeGen_OpenGLCompute_C::add_kernel(const Stmt &s, stream << "#version 430\n"; } stream << "float float_from_bits(int x) { return intBitsToFloat(int(x)); }\n"; - stream << "#define halide_unused(x) (void)(x)\n"; + stream << "#define halide_maybe_unused(x) (void)(x)\n"; for (size_t i = 0; i < args.size(); i++) { if (args[i].is_buffer) { diff --git a/src/CodeGen_PTX_Dev.cpp b/src/CodeGen_PTX_Dev.cpp index 1baa55bbd1d0..711040f54afd 100644 --- a/src/CodeGen_PTX_Dev.cpp +++ b/src/CodeGen_PTX_Dev.cpp @@ -91,7 +91,8 @@ class CodeGen_PTX_Dev : public CodeGen_LLVM, public CodeGen_GPU_Dev { // @} std::string march() const; - std::string mcpu() const override; + std::string mcpu_target() const override; + std::string mcpu_tune() const override; std::string mattrs() const override; bool use_soft_float_abi() const override; int native_vector_bits() const override; @@ -153,7 +154,7 @@ void CodeGen_PTX_Dev::add_kernel(Stmt stmt, // Make our function FunctionType *func_t = FunctionType::get(void_t, arg_types, false); function = llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, name, module.get()); - set_function_attributes_for_target(function, target); + set_function_attributes_from_halide_target_options(*function); // Mark the buffer args as no alias for (size_t i = 0; i < args.size(); i++) { @@ -542,9 +543,8 @@ string CodeGen_PTX_Dev::march() const { return "nvptx64"; } -string CodeGen_PTX_Dev::mcpu() const { +string CodeGen_PTX_Dev::mcpu_target() const { if (target.has_feature(Target::CUDACapability86)) { - user_assert(LLVM_VERSION >= 130) << "The linked LLVM version does not support cuda compute capability 8.6\n"; return "sm_86"; } else if (target.has_feature(Target::CUDACapability80)) { return "sm_80"; @@ -567,6 +567,10 @@ string CodeGen_PTX_Dev::mcpu() const { } } +string CodeGen_PTX_Dev::mcpu_tune() const { + return mcpu_target(); +} + string CodeGen_PTX_Dev::mattrs() const { if (target.has_feature(Target::CUDACapability86)) { return "+ptx71"; @@ -615,15 +619,10 @@ vector CodeGen_PTX_Dev::compile_to_src() { options.HonorSignDependentRoundingFPMathOption = false; options.NoZerosInBSS = false; options.GuaranteedTailCallOpt = false; -#if LLVM_VERSION >= 130 - // nothing -#else - options.StackAlignmentOverride = 0; -#endif std::unique_ptr target_machine(llvm_target->createTargetMachine(triple.str(), - mcpu(), mattrs(), options, + mcpu_target(), mattrs(), options, llvm::Reloc::PIC_, llvm::CodeModel::Small, CodeGenOpt::Aggressive)); @@ -637,22 +636,6 @@ vector CodeGen_PTX_Dev::compile_to_src() { raw_svector_ostream ostream(outstr); ostream.SetUnbuffered(); - // NOTE: use of the "legacy" PassManager here is still required; it is deprecated - // for optimization, but is still the only complete API for codegen as of work-in-progress - // LLVM14. At the time of this comment (Dec 2021), there is no firm plan as to when codegen will - // be fully available in the new PassManager, so don't worry about this 'legacy' - // tag until there's any indication that the old APIs start breaking. - // - // See: - // https://lists.llvm.org/pipermail/llvm-dev/2021-April/150100.html - // https://releases.llvm.org/13.0.0/docs/ReleaseNotes.html#changes-to-the-llvm-ir - // https://groups.google.com/g/llvm-dev/c/HoS07gXx0p8 - legacy::FunctionPassManager function_pass_manager(module.get()); - legacy::PassManager module_pass_manager; - - module_pass_manager.add(createTargetTransformInfoWrapperPass(target_machine->getTargetIRAnalysis())); - function_pass_manager.add(createTargetTransformInfoWrapperPass(target_machine->getTargetIRAnalysis())); - // NVidia's libdevice library uses a __nvvm_reflect to choose // how to handle denormalized numbers. (The pass replaces calls // to __nvvm_reflect with a constant via a map lookup. The inliner @@ -679,26 +662,72 @@ vector CodeGen_PTX_Dev::compile_to_src() { } } - // At present, we default to *enabling* LLVM loop optimization, - // unless DisableLLVMLoopOpt is set; we're going to flip this to defaulting - // to *not* enabling these optimizations (and removing the DisableLLVMLoopOpt feature). - // See https://github.com/halide/Halide/issues/4113 for more info. - // (Note that setting EnableLLVMLoopOpt always enables loop opt, regardless - // of the setting of DisableLLVMLoopOpt.) - const bool do_loop_opt = !target.has_feature(Target::DisableLLVMLoopOpt) || - target.has_feature(Target::EnableLLVMLoopOpt); + // halide_target_feature_disable_llvm_loop_opt is deprecated in Halide 15 + // (and will be removed in Halide 16). Halide 15 now defaults to disabling + // LLVM loop optimization, unless halide_target_feature_enable_llvm_loop_opt is set. + if (get_target().has_feature(Target::DisableLLVMLoopOpt)) { + user_warning << "halide_target_feature_disable_llvm_loop_opt is deprecated in Halide 15 " + "(and will be removed in Halide 16). Halide 15 now defaults to disabling " + "LLVM loop optimization, unless halide_target_feature_enable_llvm_loop_opt is set.\n"; + } + const bool do_loop_opt = get_target().has_feature(Target::EnableLLVMLoopOpt); + + // Define and run optimization pipeline with new pass manager + PipelineTuningOptions pto; + pto.LoopInterleaving = do_loop_opt; + pto.LoopVectorization = do_loop_opt; + pto.SLPVectorization = true; // Note: SLP vectorization has no analogue in the Halide scheduling model + pto.LoopUnrolling = do_loop_opt; + pto.ForgetAllSCEVInLoopUnroll = true; + + llvm::PassBuilder pb(target_machine.get(), pto); + + bool debug_pass_manager = false; + // These analysis managers have to be declared in this order. + llvm::LoopAnalysisManager lam; + llvm::FunctionAnalysisManager fam; + llvm::CGSCCAnalysisManager cgam; + llvm::ModuleAnalysisManager mam; + + // Register all the basic analyses with the managers. + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + ModulePassManager mpm; + +#if LLVM_VERSION >= 140 + using OptimizationLevel = llvm::OptimizationLevel; +#else + using OptimizationLevel = PassBuilder::OptimizationLevel; +#endif + + OptimizationLevel level = OptimizationLevel::O3; + + target_machine->registerPassBuilderCallbacks(pb); - PassManagerBuilder b; - b.OptLevel = 3; - b.Inliner = createFunctionInliningPass(b.OptLevel, 0, false); - b.LoopVectorize = do_loop_opt; - b.SLPVectorize = true; - b.DisableUnrollLoops = !do_loop_opt; + mpm = pb.buildPerModuleDefaultPipeline(level, debug_pass_manager); + mpm.run(*module, mam); + + if (llvm::verifyModule(*module, &errs())) { + report_fatal_error("Transformation resulted in an invalid module\n"); + } - target_machine->adjustPassManager(b); + // Optimization pipeline completed; run codegen pipeline - b.populateFunctionPassManager(function_pass_manager); - b.populateModulePassManager(module_pass_manager); + // NOTE: use of the "legacy" PassManager here is still required; it is deprecated + // for optimization, but is still the only complete API for codegen as of work-in-progress + // LLVM14. At the time of this comment (Dec 2021), there is no firm plan as to when codegen will + // be fully available in the new PassManager, so don't worry about this 'legacy' + // tag until there's any indication that the old APIs start breaking. + // + // See: + // https://lists.llvm.org/pipermail/llvm-dev/2021-April/150100.html + // https://releases.llvm.org/13.0.0/docs/ReleaseNotes.html#changes-to-the-llvm-ir + // https://groups.google.com/g/llvm-dev/c/HoS07gXx0p8 + legacy::PassManager module_pass_manager; + module_pass_manager.add(createTargetTransformInfoWrapperPass(target_machine->getTargetIRAnalysis())); // Override default to generate verbose assembly. target_machine->Options.MCOptions.AsmVerbose = true; @@ -709,18 +738,10 @@ vector CodeGen_PTX_Dev::compile_to_src() { bool fail = target_machine->addPassesToEmitFile(module_pass_manager, ostream, nullptr, ::llvm::CGFT_AssemblyFile, true); - if (fail) { - internal_error << "Failed to set up passes to emit PTX source\n"; - } - - // Run optimization passes - function_pass_manager.doInitialization(); - for (auto &function : *module) { - function_pass_manager.run(function); - } - function_pass_manager.doFinalization(); + internal_assert(!fail) << "Failed to set up passes to emit PTX source\n"; module_pass_manager.run(*module); + // Codegen pipeline completed. if (debug::debug_level() >= 2) { dump(); } @@ -742,7 +763,7 @@ vector CodeGen_PTX_Dev::compile_to_src() { f.write(buffer.data(), buffer.size()); f.close(); - string cmd = "ptxas --gpu-name " + mcpu() + " " + ptx.pathname() + " -o " + sass.pathname(); + string cmd = "ptxas --gpu-name " + mcpu_target() + " " + ptx.pathname() + " -o " + sass.pathname(); if (system(cmd.c_str()) == 0) { cmd = "nvdisasm " + sass.pathname(); int ret = system(cmd.c_str()); diff --git a/src/CodeGen_PowerPC.cpp b/src/CodeGen_PowerPC.cpp index 42dec77fd75d..7f1e7252e941 100644 --- a/src/CodeGen_PowerPC.cpp +++ b/src/CodeGen_PowerPC.cpp @@ -22,7 +22,8 @@ class CodeGen_PowerPC : public CodeGen_Posix { protected: void init_module() override; - string mcpu() const override; + string mcpu_target() const override; + string mcpu_tune() const override; string mattrs() const override; bool use_soft_float_abi() const override; int native_vector_bits() const override; @@ -141,7 +142,7 @@ void CodeGen_PowerPC::visit(const Max *op) { return CodeGen_Posix::visit(op); } -string CodeGen_PowerPC::mcpu() const { +string CodeGen_PowerPC::mcpu_target() const { if (target.bits == 32) { return "ppc32"; } else { @@ -155,6 +156,10 @@ string CodeGen_PowerPC::mcpu() const { } } +string CodeGen_PowerPC::mcpu_tune() const { + return mcpu_target(); +} + string CodeGen_PowerPC::mattrs() const { string features; string separator; diff --git a/src/CodeGen_RISCV.cpp b/src/CodeGen_RISCV.cpp index 01395f596b91..434105724c3a 100644 --- a/src/CodeGen_RISCV.cpp +++ b/src/CodeGen_RISCV.cpp @@ -19,7 +19,8 @@ class CodeGen_RISCV : public CodeGen_Posix { protected: using CodeGen_Posix::visit; - string mcpu() const override; + string mcpu_target() const override; + string mcpu_tune() const override; string mattrs() const override; string mabi() const override; bool use_soft_float_abi() const override; @@ -30,10 +31,14 @@ CodeGen_RISCV::CodeGen_RISCV(const Target &t) : CodeGen_Posix(t) { } -string CodeGen_RISCV::mcpu() const { +string CodeGen_RISCV::mcpu_target() const { return ""; } +string CodeGen_RISCV::mcpu_tune() const { + return mcpu_target(); +} + string CodeGen_RISCV::mattrs() const { // Note: the default march is "rv[32|64]imafdc", // which includes standard extensions: diff --git a/src/CodeGen_WebAssembly.cpp b/src/CodeGen_WebAssembly.cpp index 726330e47931..c9776ddc1e89 100644 --- a/src/CodeGen_WebAssembly.cpp +++ b/src/CodeGen_WebAssembly.cpp @@ -29,13 +29,15 @@ class CodeGen_WebAssembly : public CodeGen_Posix { void init_module() override; - string mcpu() const override; + string mcpu_target() const override; + string mcpu_tune() const override; string mattrs() const override; bool use_soft_float_abi() const override; int native_vector_bits() const override; bool use_pic() const override; void visit(const Cast *) override; + void visit(const Call *) override; void codegen_vector_reduce(const VectorReduce *, const Expr &) override; }; @@ -61,17 +63,10 @@ const WasmIntrinsic intrinsic_defs[] = { {"llvm.uadd.sat.v16i8", UInt(8, 16), "saturating_add", {UInt(8, 16), UInt(8, 16)}, Target::WasmSimd128}, // TODO: Are these really different than the standard llvm.*sub.sat.*? -#if LLVM_VERSION >= 130 {"llvm.wasm.sub.sat.signed.v16i8", Int(8, 16), "saturating_sub", {Int(8, 16), Int(8, 16)}, Target::WasmSimd128}, {"llvm.wasm.sub.sat.unsigned.v16i8", UInt(8, 16), "saturating_sub", {UInt(8, 16), UInt(8, 16)}, Target::WasmSimd128}, {"llvm.wasm.sub.sat.signed.v8i16", Int(16, 8), "saturating_sub", {Int(16, 8), Int(16, 8)}, Target::WasmSimd128}, {"llvm.wasm.sub.sat.unsigned.v8i16", UInt(16, 8), "saturating_sub", {UInt(16, 8), UInt(16, 8)}, Target::WasmSimd128}, -#else - {"llvm.wasm.sub.saturate.signed.v16i8", Int(8, 16), "saturating_sub", {Int(8, 16), Int(8, 16)}, Target::WasmSimd128}, - {"llvm.wasm.sub.saturate.unsigned.v16i8", UInt(8, 16), "saturating_sub", {UInt(8, 16), UInt(8, 16)}, Target::WasmSimd128}, - {"llvm.wasm.sub.saturate.signed.v8i16", Int(16, 8), "saturating_sub", {Int(16, 8), Int(16, 8)}, Target::WasmSimd128}, - {"llvm.wasm.sub.saturate.unsigned.v8i16", UInt(16, 8), "saturating_sub", {UInt(16, 8), UInt(16, 8)}, Target::WasmSimd128}, -#endif {"llvm.wasm.avgr.unsigned.v16i8", UInt(8, 16), "rounding_halving_add", {UInt(8, 16), UInt(8, 16)}, Target::WasmSimd128}, {"llvm.wasm.avgr.unsigned.v8i16", UInt(16, 8), "rounding_halving_add", {UInt(16, 8), UInt(16, 8)}, Target::WasmSimd128}, @@ -80,7 +75,6 @@ const WasmIntrinsic intrinsic_defs[] = { {"float_to_double", Float(64, 4), "float_to_double", {Float(32, 4)}, Target::WasmSimd128}, #endif -#if LLVM_VERSION >= 130 // With some work, some of these could possibly be adapted to work under earlier versions of LLVM. {"widening_mul_i8x16", Int(16, 16), "widening_mul", {Int(8, 16), Int(8, 16)}, Target::WasmSimd128}, {"widening_mul_i16x8", Int(32, 8), "widening_mul", {Int(16, 8), Int(16, 8)}, Target::WasmSimd128}, @@ -118,7 +112,6 @@ const WasmIntrinsic intrinsic_defs[] = { {"extend_u16x8_to_u32x8", UInt(32, 8), "widen_integer", {UInt(16, 8)}, Target::WasmSimd128}, {"extend_i32x4_to_i64x4", Int(64, 4), "widen_integer", {Int(32, 4)}, Target::WasmSimd128}, {"extend_u32x4_to_u64x4", UInt(64, 4), "widen_integer", {UInt(32, 4)}, Target::WasmSimd128}, -#endif }; // clang-format on @@ -147,7 +140,6 @@ void CodeGen_WebAssembly::init_module() { } void CodeGen_WebAssembly::visit(const Cast *op) { -#if LLVM_VERSION >= 130 struct Pattern { std::string intrin; ///< Name of the intrinsic Expr pattern; ///< The pattern to match against @@ -156,11 +148,6 @@ void CodeGen_WebAssembly::visit(const Cast *op) { // clang-format off static const Pattern patterns[] = { - {"q15mulr_sat_s", i16_sat(rounding_shift_right(widening_mul(wild_i16x_, wild_i16x_), u16(15))), Target::WasmSimd128}, - {"saturating_narrow", i8_sat(wild_i16x_), Target::WasmSimd128}, - {"saturating_narrow", u8_sat(wild_i16x_), Target::WasmSimd128}, - {"saturating_narrow", i16_sat(wild_i32x_), Target::WasmSimd128}, - {"saturating_narrow", u16_sat(wild_i32x_), Target::WasmSimd128}, {"int_to_double", f64(wild_i32x_), Target::WasmSimd128}, {"int_to_double", f64(wild_u32x_), Target::WasmSimd128}, #if LLVM_VERSION == 130 @@ -189,13 +176,46 @@ void CodeGen_WebAssembly::visit(const Cast *op) { } } } -#endif // LLVM_VERSION >= 130 + + CodeGen_Posix::visit(op); +} + +void CodeGen_WebAssembly::visit(const Call *op) { + struct Pattern { + std::string intrin; ///< Name of the intrinsic + Expr pattern; ///< The pattern to match against + Target::Feature required_feature; + }; + + // clang-format off + static const Pattern patterns[] = { + {"q15mulr_sat_s", i16_sat(rounding_shift_right(widening_mul(wild_i16x_, wild_i16x_), u16(15))), Target::WasmSimd128}, + {"saturating_narrow", i8_sat(wild_i16x_), Target::WasmSimd128}, + {"saturating_narrow", u8_sat(wild_i16x_), Target::WasmSimd128}, + {"saturating_narrow", i16_sat(wild_i32x_), Target::WasmSimd128}, + {"saturating_narrow", u16_sat(wild_i32x_), Target::WasmSimd128}, + }; + // clang-format on + + if (op->type.is_vector()) { + std::vector matches; + for (const Pattern &p : patterns) { + if (!target.has_feature(p.required_feature)) { + continue; + } + if (expr_match(p.pattern, op, matches)) { + value = call_overloaded_intrin(op->type, p.intrin, matches); + if (value) { + return; + } + } + } + } CodeGen_Posix::visit(op); } void CodeGen_WebAssembly::codegen_vector_reduce(const VectorReduce *op, const Expr &init) { -#if LLVM_VERSION >= 130 struct Pattern { VectorReduce::Operator reduce_op; int factor; @@ -264,15 +284,18 @@ void CodeGen_WebAssembly::codegen_vector_reduce(const VectorReduce *op, const Ex } } } -#endif // LLVM_VERSION >= 130 CodeGen_Posix::codegen_vector_reduce(op, init); } -string CodeGen_WebAssembly::mcpu() const { +string CodeGen_WebAssembly::mcpu_target() const { return ""; } +string CodeGen_WebAssembly::mcpu_tune() const { + return mcpu_target(); +} + string CodeGen_WebAssembly::mattrs() const { std::ostringstream s; string sep; diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index b75500ee2684..5d599409fb61 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -53,7 +53,8 @@ class CodeGen_X86 : public CodeGen_Posix { CodeGen_X86(Target); protected: - string mcpu() const override; + string mcpu_target() const override; + string mcpu_tune() const override; string mattrs() const override; bool use_soft_float_abi() const override; int native_vector_bits() const override; @@ -130,6 +131,11 @@ const x86Intrinsic intrinsic_defs[] = { {"llvm.ssub.sat.v16i16", Int(16, 16), "saturating_sub", {Int(16, 16), Int(16, 16)}, Target::AVX2}, {"llvm.ssub.sat.v8i16", Int(16, 8), "saturating_sub", {Int(16, 8), Int(16, 8)}}, + // Sum of absolute differences + {"llvm.x86.sse2.psad.bw", UInt(64, 2), "sum_of_absolute_differences", {UInt(8, 16), UInt(8, 16)}}, + {"llvm.x86.avx2.psad.bw", UInt(64, 4), "sum_of_absolute_differences", {UInt(8, 32), UInt(8, 32)}, Target::AVX2}, + {"llvm.x86.avx512.psad.bw.512", UInt(64, 8), "sum_of_absolute_differences", {UInt(8, 64), UInt(8, 64)}, Target::AVX512_Skylake}, + // Some of the instructions referred to below only appear with // AVX2, but LLVM generates better AVX code if you give it // full 256-bit vectors and let it do the slicing up into @@ -163,6 +169,10 @@ const x86Intrinsic intrinsic_defs[] = { {"packuswbx32", UInt(8, 32), "saturating_narrow", {Int(16, 32)}, Target::AVX2}, {"packuswbx16", UInt(8, 16), "saturating_narrow", {Int(16, 16)}}, + // Widening multiplies that use (v)pmaddwd + {"wmul_pmaddwd_avx2", Int(32, 8), "widening_mul", {Int(16, 8), Int(16, 8)}, Target::AVX2}, + {"wmul_pmaddwd_sse2", Int(32, 4), "widening_mul", {Int(16, 4), Int(16, 4)}}, + // Multiply keep high half {"llvm.x86.avx2.pmulh.w", Int(16, 16), "pmulh", {Int(16, 16), Int(16, 16)}, Target::AVX2}, {"llvm.x86.avx2.pmulhu.w", UInt(16, 16), "pmulh", {UInt(16, 16), UInt(16, 16)}, Target::AVX2}, @@ -184,6 +194,14 @@ const x86Intrinsic intrinsic_defs[] = { {"llvm.x86.avx2.pmadd.ub.sw", Int(16, 16), "saturating_dot_product", {UInt(8, 32), Int(8, 32)}, Target::AVX2}, {"llvm.x86.ssse3.pmadd.ub.sw.128", Int(16, 8), "saturating_dot_product", {UInt(8, 16), Int(8, 16)}, Target::SSE41}, + // Horizontal widening adds using 2-way dot products. + {"hadd_pmadd_u8_sse3", UInt(16, 8), "horizontal_widening_add", {UInt(8, 16)}, Target::SSE41}, + {"hadd_pmadd_u8_sse3", Int(16, 8), "horizontal_widening_add", {UInt(8, 16)}, Target::SSE41}, + {"hadd_pmadd_i8_sse3", Int(16, 8), "horizontal_widening_add", {Int(8, 16)}, Target::SSE41}, + {"hadd_pmadd_u8_avx2", UInt(16, 16), "horizontal_widening_add", {UInt(8, 32)}, Target::AVX2}, + {"hadd_pmadd_u8_avx2", Int(16, 16), "horizontal_widening_add", {UInt(8, 32)}, Target::AVX2}, + {"hadd_pmadd_i8_avx2", Int(16, 16), "horizontal_widening_add", {Int(8, 32)}, Target::AVX2}, + {"llvm.x86.avx512.pmaddw.d.512", Int(32, 16), "dot_product", {Int(16, 32), Int(16, 32)}, Target::AVX512_Skylake}, {"llvm.x86.avx512.pmaddw.d.512", Int(32, 16), "dot_product", {Int(16, 32), Int(16, 32)}, Target::AVX512_Cannonlake}, {"llvm.x86.avx2.pmadd.wd", Int(32, 8), "dot_product", {Int(16, 16), Int(16, 16)}, Target::AVX2}, @@ -455,11 +473,6 @@ void CodeGen_X86::visit(const Cast *op) { // saturate the result. {"pmulhrs", i16(rounding_shift_right(widening_mul(wild_i16x_, wild_i16x_), 15))}, - {"saturating_narrow", i16_sat(wild_i32x_)}, - {"saturating_narrow", u16_sat(wild_i32x_)}, - {"saturating_narrow", i8_sat(wild_i16x_)}, - {"saturating_narrow", u8_sat(wild_i16x_)}, - {"f32_to_bf16", bf16(wild_f32x_)}, }; // clang-format on @@ -518,6 +531,18 @@ void CodeGen_X86::visit(const Call *op) { return; } } + } else if (op->type.is_int() && + op->type.bits() <= 16 && + op->is_intrinsic(Call::rounding_halving_add)) { + // We can redirect signed rounding halving add to unsigned rounding + // halving add by adding 128 / 32768 to the result if the sign of the + // args differs. + internal_assert(op->args.size() == 2); + Type t = op->type.with_code(halide_type_uint); + Expr a = cast(t, op->args[0]); + Expr b = cast(t, op->args[1]); + codegen(cast(op->type, rounding_halving_add(a, b) + ((a ^ b) & (1 << (t.bits() - 1))))); + return; } else if (op->is_intrinsic(Call::absd)) { internal_assert(op->args.size() == 2); if (op->args[0].type().is_uint()) { @@ -545,6 +570,10 @@ void CodeGen_X86::visit(const Call *op) { {"pmulh", mul_shift_right(wild_i16x_, wild_i16x_, 16)}, {"pmulh", mul_shift_right(wild_u16x_, wild_u16x_, 16)}, {"saturating_pmulhrs", rounding_mul_shift_right(wild_i16x_, wild_i16x_, 15)}, + {"saturating_narrow", i16_sat(wild_i32x_)}, + {"saturating_narrow", u16_sat(wild_i32x_)}, + {"saturating_narrow", i8_sat(wild_i16x_)}, + {"saturating_narrow", u8_sat(wild_i16x_)}, }; // clang-format on @@ -578,6 +607,7 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init enum { CombineInit = 1 << 0, SwapOperands = 1 << 1, + SingleArg = 1 << 2, }; }; // clang-format off @@ -607,8 +637,15 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init {VectorReduce::Add, 2, wild_f32x_ * wild_f32x_, "dot_product", BFloat(16), Pattern::CombineInit}, // One could do a horizontal widening addition with - // dot_product against a vector of ones. Currently disabled - // because I haven't found case where it's clearly better. + // other dot_products against a vector of ones. Currently disabled + // because I haven't found other cases where it's clearly better. + {VectorReduce::Add, 2, u16(wild_u8x_), "horizontal_widening_add", {}, Pattern::SingleArg}, + {VectorReduce::Add, 2, i16(wild_u8x_), "horizontal_widening_add", {}, Pattern::SingleArg}, + {VectorReduce::Add, 2, i16(wild_i8x_), "horizontal_widening_add", {}, Pattern::SingleArg}, + + // Sum of absolute differences + {VectorReduce::Add, 8, u64(absd(wild_u8x_, wild_u8x_)), "sum_of_absolute_differences", {}}, + }; // clang-format on @@ -618,38 +655,93 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init continue; } if (expr_match(p.pattern, op->value, matches)) { - Expr a = matches[0]; - Expr b = matches[1]; - if (p.flags & Pattern::SwapOperands) { - std::swap(a, b); - } - if (p.narrow_type.bits() > 0) { - a = lossless_cast(p.narrow_type.with_lanes(a.type().lanes()), a); - b = lossless_cast(p.narrow_type.with_lanes(b.type().lanes()), b); - } - if (!a.defined() || !b.defined()) { - continue; - } + if (p.flags & Pattern::SingleArg) { + Expr a = matches[0]; - if (init.defined() && (p.flags & Pattern::CombineInit)) { - value = call_overloaded_intrin(op->type, p.intrin, {init, a, b}); - if (value) { - return; + if (p.narrow_type.bits() > 0) { + a = lossless_cast(p.narrow_type.with_lanes(a.type().lanes()), a); + } + if (!a.defined()) { + continue; + } + + if (init.defined() && (p.flags & Pattern::CombineInit)) { + value = call_overloaded_intrin(op->type, p.intrin, {init, a}); + if (value) { + return; + } + } else { + value = call_overloaded_intrin(op->type, p.intrin, {a}); + if (value) { + if (init.defined()) { + Value *x = value; + Value *y = codegen(init); + value = builder->CreateAdd(x, y); + } + return; + } } } else { - value = call_overloaded_intrin(op->type, p.intrin, {a, b}); - if (value) { - if (init.defined()) { - Value *x = value; - Value *y = codegen(init); - value = builder->CreateAdd(x, y); + Expr a = matches[0]; + Expr b = matches[1]; + if (p.flags & Pattern::SwapOperands) { + std::swap(a, b); + } + if (p.narrow_type.bits() > 0) { + a = lossless_cast(p.narrow_type.with_lanes(a.type().lanes()), a); + b = lossless_cast(p.narrow_type.with_lanes(b.type().lanes()), b); + } + if (!a.defined() || !b.defined()) { + continue; + } + + if (init.defined() && (p.flags & Pattern::CombineInit)) { + value = call_overloaded_intrin(op->type, p.intrin, {init, a, b}); + if (value) { + return; + } + } else { + value = call_overloaded_intrin(op->type, p.intrin, {a, b}); + if (value) { + if (init.defined()) { + Value *x = value; + Value *y = codegen(init); + value = builder->CreateAdd(x, y); + } + return; } - return; } } } } + // Rewrite non-native sum-of-absolute-difference variants to the native + // op. We support reducing to various types. We could consider supporting + // multiple reduction factors too, but in general we don't handle non-native + // reduction factors for VectorReduce nodes (yet?). + if (op->op == VectorReduce::Add && + factor == 8) { + const Cast *cast = op->value.as(); + const Call *call = cast ? cast->value.as() : nullptr; + if (call && + call->is_intrinsic(Call::absd) && + cast->type.element_of().can_represent(UInt(8)) && + (cast->type.is_int() || cast->type.is_uint()) && + call->args[0].type().element_of() == UInt(8)) { + + internal_assert(cast->type.element_of() != UInt(64)) << "Should have pattern-matched above\n"; + + // Cast to uint64 instead + Expr equiv = Cast::make(UInt(64, cast->value.type().lanes()), cast->value); + // Reduce on that to hit psadbw + equiv = VectorReduce::make(VectorReduce::Add, equiv, op->type.lanes()); + // Then cast that to the desired type + equiv = Cast::make(cast->type.with_lanes(equiv.type().lanes()), equiv); + codegen(equiv); + return; + } + } + CodeGen_Posix::codegen_vector_reduce(op, init); } @@ -663,7 +755,7 @@ void CodeGen_X86::visit(const Load *op) { const Ramp *ramp = op->index.as(); internal_assert(ramp) << "Expected AMXTile to have index ramp\n"; Value *ptr = codegen_buffer_pointer(op->name, op->type, ramp->base); - LoadInst *load = builder->CreateAlignedLoad(ptr->getType()->getPointerElementType(), ptr, llvm::Align(op->type.bytes())); + LoadInst *load = builder->CreateAlignedLoad(llvm_type_of(upgrade_type_for_storage(op->type)), ptr, llvm::Align(op->type.bytes())); add_tbaa_metadata(load, op->name, op->index); value = load; return; @@ -685,7 +777,10 @@ void CodeGen_X86::visit(const Store *op) { CodeGen_Posix::visit(op); } -string CodeGen_X86::mcpu() const { +string CodeGen_X86::mcpu_target() const { + // Perform an ad-hoc guess for the -mcpu given features. + // WARNING: this is used to drive -mcpu, *NOT* -mtune! + // The CPU choice here *WILL* affect -mattrs! if (target.has_feature(Target::AVX512_SapphireRapids)) { return "sapphirerapids"; } else if (target.has_feature(Target::AVX512_Cannonlake)) { @@ -707,6 +802,43 @@ string CodeGen_X86::mcpu() const { } } +string CodeGen_X86::mcpu_tune() const { + // Check if any explicit request for tuning exists. + switch (target.processor_tune) { // Please keep sorted. + case Target::Processor::AMDFam10: + return "amdfam10"; + case Target::Processor::BdVer1: + return "bdver1"; + case Target::Processor::BdVer2: + return "bdver2"; + case Target::Processor::BdVer3: + return "bdver3"; + case Target::Processor::BdVer4: + return "bdver4"; + case Target::Processor::BtVer1: + return "btver1"; + case Target::Processor::BtVer2: + return "btver2"; + case Target::Processor::K8: + return "k8"; + case Target::Processor::K8_SSE3: + return "k8-sse3"; + case Target::Processor::ZnVer1: + return "znver1"; + case Target::Processor::ZnVer2: + return "znver2"; + case Target::Processor::ZnVer3: + return "znver3"; + + case Target::Processor::ProcessorGeneric: + break; + } + internal_assert(target.processor_tune == Target::Processor::ProcessorGeneric && "The switch should be exhaustive."); + return mcpu_target(); // Detect "best" CPU from the enabled ISA's. +} + +// FIXME: we should lower everything here, instead of relying +// that -mcpu= (`mcpu_target()`) implies/sets features for us. string CodeGen_X86::mattrs() const { string features; string separator; diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index 5d46d60bf09e..f5840a0074b3 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -275,6 +275,16 @@ class Deinterleaver : public IRGraphMutator { return expr; } + Expr give_up_and_shuffle(const Expr &e) { + // Uh-oh, we don't know how to deinterleave this vector expression + // Make llvm do it + std::vector indices; + for (int i = 0; i < new_lanes; i++) { + indices.push_back(starting_lane + lane_stride * i); + } + return Shuffle::make({e}, indices); + } + Expr visit(const Variable *op) override { if (op->type.is_scalar()) { return op; @@ -302,13 +312,7 @@ class Deinterleaver : public IRGraphMutator { lane_stride == 3) { return Variable::make(t, op->name + ".lanes_2_of_3", op->image, op->param, op->reduction_domain); } else { - // Uh-oh, we don't know how to deinterleave this vector expression - // Make llvm do it - std::vector indices; - for (int i = 0; i < new_lanes; i++) { - indices.push_back(starting_lane + lane_stride * i); - } - return Shuffle::make({op}, indices); + return give_up_and_shuffle(op); } } } @@ -322,6 +326,17 @@ class Deinterleaver : public IRGraphMutator { } } + Expr visit(const Reinterpret *op) override { + if (op->type.is_scalar()) { + return op; + } else if (op->type.bits() != op->value.type().bits()) { + return give_up_and_shuffle(op); + } else { + Type t = op->type.with_lanes(new_lanes); + return Reinterpret::make(t, mutate(op->value)); + } + } + Expr visit(const Call *op) override { Type t = op->type.with_lanes(new_lanes); diff --git a/src/Derivative.cpp b/src/Derivative.cpp index c5ba26253367..08a1c617ca00 100644 --- a/src/Derivative.cpp +++ b/src/Derivative.cpp @@ -55,6 +55,7 @@ class ReverseAccumulationVisitor : public IRVisitor { void visit(const FloatImm *) override; void visit(const StringImm *) override; void visit(const Cast *op) override; + void visit(const Reinterpret *op) override; void visit(const Variable *op) override; void visit(const Add *op) override; void visit(const Sub *op) override; @@ -739,7 +740,6 @@ void ReverseAccumulationVisitor::propagate_adjoints( update_args, i); } - int count = 0; // Traverse the expressions in reverse order for (auto it = expr_list.rbegin(); it != expr_list.rend(); it++) { if (it->type().is_handle()) { @@ -748,7 +748,6 @@ void ReverseAccumulationVisitor::propagate_adjoints( } // Propagate adjoints it->accept(this); - count++; } } } @@ -836,6 +835,14 @@ void ReverseAccumulationVisitor::visit(const Cast *op) { } } +void ReverseAccumulationVisitor::visit(const Reinterpret *op) { + internal_assert(expr_adjoints.find(op) != expr_adjoints.end()); + Expr adjoint = expr_adjoints[op]; + + // bit manipulation -- has zero derivative. + accumulate(op->value, make_zero(op->type)); +} + void ReverseAccumulationVisitor::visit(const Variable *op) { internal_assert(expr_adjoints.find(op) != expr_adjoints.end()); Expr adjoint = expr_adjoints[op]; @@ -1169,8 +1176,7 @@ void ReverseAccumulationVisitor::visit(const Call *op) { accumulate(op->args[1], adjoint); } else if (op->is_intrinsic(Call::undef)) { // do nothing - } else if (op->is_intrinsic(Call::reinterpret) || - op->is_intrinsic(Call::bitwise_and) || + } else if (op->is_intrinsic(Call::bitwise_and) || op->is_intrinsic(Call::bitwise_not) || op->is_intrinsic(Call::bitwise_or) || op->is_intrinsic(Call::bitwise_xor) || @@ -1940,6 +1946,15 @@ Func Derivative::operator()(const Param<> ¶m) const { return it->second; } +Func Derivative::operator()(const std::string &name) const { + auto it = adjoints.find(FuncKey{name, -1}); + if (it == adjoints.end()) { + Internal::debug(1) << "Could not find name: " << name << "\n"; + return Func(); + } + return it->second; +} + Derivative propagate_adjoints(const Func &output, const Func &adjoint, const Region &output_bounds) { diff --git a/src/Derivative.h b/src/Derivative.h index 9d35b4f140ed..3100ab9c8868 100644 --- a/src/Derivative.h +++ b/src/Derivative.h @@ -35,6 +35,7 @@ class Derivative { Func operator()(const Func &func, int update_id = -1) const; Func operator()(const Buffer<> &buffer) const; Func operator()(const Param<> ¶m) const; + Func operator()(const std::string &name) const; private: const std::map adjoints; diff --git a/src/Elf.cpp b/src/Elf.cpp index 565ebd02343c..97cbfaa90af8 100644 --- a/src/Elf.cpp +++ b/src/Elf.cpp @@ -416,7 +416,7 @@ std::unique_ptr parse_object_internal(const char *data, size_t size) { internal_assert(to_relocate != obj->sections_end()); // TODO: This assert should work, but it seems like this // isn't a reliable test. We rely on the names intead. - //internal_assert(&*to_relocate == section_map[sh->sh_link]); + // internal_assert(&*to_relocate == section_map[sh->sh_link]); for (uint64_t i = 0; i < sh->sh_size / sh->sh_entsize; i++) { const char *rela_ptr = data + sh->sh_offset + i * sh->sh_entsize; internal_assert(data <= rela_ptr && rela_ptr + sizeof(Rela) <= data + size); diff --git a/src/EliminateBoolVectors.cpp b/src/EliminateBoolVectors.cpp index 2e63382f644c..cebfe0f0019b 100644 --- a/src/EliminateBoolVectors.cpp +++ b/src/EliminateBoolVectors.cpp @@ -136,6 +136,8 @@ class EliminateBoolVectors : public IRMutator { } } + // FIXME: what about Reinterpret? + Stmt visit(const Store *op) override { Expr predicate = op->predicate; if (!is_const_one(predicate)) { diff --git a/src/Error.cpp b/src/Error.cpp index ab7f66dad264..bf4e60013f75 100644 --- a/src/Error.cpp +++ b/src/Error.cpp @@ -33,8 +33,48 @@ bool exceptions_enabled() { #endif } +Error::Error(const char *msg) + : what_(new char[strlen(msg) + 1]) { + strcpy(what_, msg); +} + Error::Error(const std::string &msg) - : std::runtime_error(msg) { + : Error(msg.c_str()) { +} + +Error::Error(const Error &that) + : Error(that.what_) { +} + +Error &Error::operator=(const Error &that) { + if (this != &that) { + delete[] this->what_; + this->what_ = new char[strlen(that.what_) + 1]; + strcpy(this->what_, that.what_); + } + return *this; +} + +Error::Error(Error &&that) noexcept { + this->what_ = that.what_; + that.what_ = nullptr; +} + +Error &Error::operator=(Error &&that) noexcept { + if (this != &that) { + delete[] this->what_; + this->what_ = that.what_; + that.what_ = nullptr; + } + return *this; +} + +Error::~Error() { + delete[] what_; +} + +const char *Error::what() const noexcept { + return what_; } CompileError::CompileError(const std::string &msg) @@ -49,6 +89,18 @@ InternalError::InternalError(const std::string &msg) : Error(msg) { } +CompileError::CompileError(const char *msg) + : Error(msg) { +} + +RuntimeError::RuntimeError(const char *msg) + : Error(msg) { +} + +InternalError::InternalError(const char *msg) + : Error(msg) { +} + namespace Internal { // Force the classes to exist, even if exceptions are off @@ -96,11 +148,7 @@ ErrorReport::ErrorReport(const char *file, int line, const char *condition_strin } } -ErrorReport::~ErrorReport() -#if __cplusplus >= 201100 || _MSC_VER >= 1900 - noexcept(false) -#endif -{ +ErrorReport::~ErrorReport() noexcept(false) { if (!msg.str().empty() && msg.str().back() != '\n') { msg << "\n"; } @@ -123,6 +171,8 @@ ErrorReport::~ErrorReport() return; } + debug(1) << msg.str(); + #ifdef HALIDE_WITH_EXCEPTIONS if (std::uncaught_exceptions() > 0) { // This should never happen - evaluating one of the arguments @@ -138,7 +188,6 @@ ErrorReport::~ErrorReport() throw InternalError(msg.str()); } #else - std::cerr << msg.str(); abort(); #endif } diff --git a/src/Error.h b/src/Error.h index 4896946219cc..822ba4cf0d46 100644 --- a/src/Error.h +++ b/src/Error.h @@ -12,29 +12,58 @@ namespace Halide { /** Query whether Halide was compiled with exceptions. */ bool exceptions_enabled(); -/** A base class for Halide errors. */ -struct Error : public std::runtime_error { +/** A base class for Halide errors. + * + * Note that this deliberately does *not* descend from std::runtime_error, or + * even std::exception; unfortunately, std::runtime_error is not marked as + * DLLEXPORT on Windows, but Error needs to be marked as such, and mismatching + * DLLEXPORT annotations in a class inheritance hierarchy in this way can lead + * to ODR violations. Instead, we just attempt to replicate the API of + * runtime_error here. */ +struct HALIDE_EXPORT_SYMBOL Error { + Error() = delete; + // Give each class a non-inlined constructor so that the type // doesn't get separately instantiated in each compilation unit. - Error(const std::string &msg); + explicit Error(const char *msg); + explicit Error(const std::string &msg); + + Error(const Error &); + Error &operator=(const Error &); + Error(Error &&) noexcept; + Error &operator=(Error &&) noexcept; + + virtual ~Error(); + + virtual const char *what() const noexcept; + +private: + // Using a std::string here will cause MSVC to complain about the fact + // that class std::string isn't declared DLLEXPORT, even though the + // field is private; rather than suppress the warning, we'll just use + // an old-fashioned new-and-delete to keep it nice and clean. + char *what_; }; /** An error that occurs while running a JIT-compiled Halide pipeline. */ -struct RuntimeError : public Error { - RuntimeError(const std::string &msg); +struct HALIDE_EXPORT_SYMBOL RuntimeError : public Error { + explicit RuntimeError(const char *msg); + explicit RuntimeError(const std::string &msg); }; /** An error that occurs while compiling a Halide pipeline that Halide * attributes to a user error. */ -struct CompileError : public Error { - CompileError(const std::string &msg); +struct HALIDE_EXPORT_SYMBOL CompileError : public Error { + explicit CompileError(const char *msg); + explicit CompileError(const std::string &msg); }; /** An error that occurs while compiling a Halide pipeline that Halide * attributes to an internal compiler bug, or to an invalid use of * Halide's internals. */ -struct InternalError : public Error { - InternalError(const std::string &msg); +struct HALIDE_EXPORT_SYMBOL InternalError : public Error { + explicit InternalError(const char *msg); + explicit InternalError(const std::string &msg); }; /** CompileTimeErrorReporter is used at compile time (*not* runtime) when diff --git a/src/Expr.h b/src/Expr.h index b70d608d290b..ac0ec6521d68 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -33,6 +33,7 @@ enum class IRNodeType { StringImm, Broadcast, Cast, + Reinterpret, Variable, Add, Sub, diff --git a/src/ExternalCode.h b/src/ExternalCode.h index 7e75eabc9e53..8de876afc394 100644 --- a/src/ExternalCode.h +++ b/src/ExternalCode.h @@ -1,6 +1,8 @@ #ifndef HALIDE_EXTERNAL_CODE_H #define HALIDE_EXTERNAL_CODE_H +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE + #include #include "Expr.h" @@ -131,4 +133,10 @@ class ExternalCode { } // namespace Halide +#else + +#error "ExternalCode is deprecated in Halide 15 and will be removed in Halide 16" + +#endif // HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE + #endif diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 08df3ff7e39f..8fdcea73f34b 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -5,6 +5,31 @@ #include "IROperator.h" #include "Util.h" +/** \file Support extraction of AMX instructions. */ + +/** + * https://asciiflow.com/#/share/eJyVUkFugzAQ%2FMrKxwoRhdAkza23SmlySHvogQsBp7FkbGSbAoryiz6nr%2BlLugZDk6ghKvJhbXZmd2b3QEScUbIQBece4XFNFVmQQ0SqiCwegtCLSI1RMBtjZGhl8BIRAHh%2BeoFVbBSr4Pq36ZOiSOBpX5cDCEikSGhuipjzun0pmdnD4%2BqtwX9%2Ffg2cLmUcTML76WyO4VAtWJ%2Ff7kIkWMEJ6gbBae2%2F3q53OHBuFBz3TS1HodPqfvUO3%2F4wO7gQag07IXqVkCuZU4VzyApuWI5BAJkdZ0K1B2ZP2%2BwJ%2FEs%2BjhKY0EYViWFSaMAaO6kypBY1hLCtDRIvMTvsekmlsc2kiGgKMw2cxqkGIyEGjn%2FlzonoIMjPUibeQX5Q1bHGisbav%2FBh2kHW2ESzdlaZkqUltaFd9UZ25TnIrIOg%2Bb7vQykLnv661GysRSaSF1k78HkHcaSbntSReLAtTL%2FscOlaI9rxYaRzzgwUOTrZeOCokLzN0TDqRYvUqtFwB6Fvqco9S5r%2BBCiqsWmNLHabzny2Y7E4PyJHcvwBx0t%2BJw%3D%3D) + * + * LHS Matrix RHS Matrix + * + * K conceptually with AMX + * ┌────────┐ + * │12345678│ N N*4 + *M │ │ ┌──┐ ┌────────┐ + * └────────┘ │1 │ K/4│1234 │ + * │2 │ │5678 │ + * To properly multiply 2 matrices, the │3 │ └────────┘ + * AMX instructions perform many 4 byte K│4 │ + * dot products, this leads to a lot of │5 │ + * striding over 4 byte areas. │6 │ + * Normally the row of the LHS matrix, │7 │ + * 123... would multiply with the column │8 │ + * of the RHS matrix 123..., but with AMX └──┘ + * this column is split up into a matrix of columns / 4 byte and rows * 4. + * which then results in K/4 dot products per row. + * + */ + namespace Halide { namespace Internal { @@ -39,12 +64,52 @@ Type amx_op_type_result_type(AMXOpType op_ty) { } } +int amx_op_type_size(AMXOpType op_ty) { + switch (op_ty) { + case AMXOpType::Int8: + return 1; + case AMXOpType::Bfloat16: + return 2; + default: + internal_error << "Unexpected"; + return -1; + } +} + const auto wild_i32 = Variable::make(Int(32), "*"); const auto wild_i32x = Variable::make(Int(32, 0), "*"); Tile<1> get_1d_tile_index(const Expr &e) { if (const auto *r1 = e.as()) { - return {true, r1->base, {r1->stride}, {r1->lanes}}; + + const auto stride_var = Variable::make(Int(32), "stride"); + const auto v1 = Variable::make(Int(32), "v1"); + const auto v2 = Variable::make(Int(32), "v2"); + const auto v3 = Variable::make(Int(32), "v3"); + + Expr patterns[] = { + ((v1 * stride_var) + v2) * v3, + v3 * ((v1 * stride_var) + v2), + (v2 + (v1 * stride_var)) * v3, + v3 * (v2 + (v1 * stride_var)), + }; + + std::map matches; + for (const auto &pattern : patterns) { + if (expr_match(pattern, r1->base, matches)) { + auto stride = std::move(matches["stride"]); + // stride must be a constant in order to not be confused with v1 + if (stride.as()) { + return {true, r1->base, {std::move(stride)}, {r1->lanes}}; + } + + // if stride wasn't a constant then v1 could possibly be the stride if constant + auto v1_expr = std::move(matches["v1"]); + if (v1_expr.as()) { + return {true, r1->base, {std::move(v1_expr)}, {r1->lanes}}; + } + } + } } return {}; @@ -143,6 +208,169 @@ Tile<3> get_3d_tile_index(const Expr &e) { return {true, base, {x_stride, 0, r_stride}, {x_tile, y_tile, r_tile}}; } +/** + * \brief Get the 3d rhs tile index configuration + * + * \param e index expression + * \param element_width the width of the elements, 1 for u8/i8, 2 for bf16 + * \return Tile<3> the tile configuration found + * + * The pattern which is getting matched looks roughly like + * `broadcast(ramp(0, 1, r), x*y) / broadcast(4, x*y*r) + optional(broadcast(base, x*y*r)) * broadcast(8, x*y*r) + + * broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r) + + * broadcast(ramp(broadcast(_, r), broadcast(4, r), x) , y)` + */ +Tile<3> get_3d_rhs_tile_index(const Expr &e, int element_width) { + const auto *sub = e.as(); + const Add *add_lhs = nullptr; + + // there's not always a sub pattern + // This depends on whether we have an ImageParam or a Buffer + if (!sub) { + add_lhs = e.as(); + } else { + add_lhs = sub->a.as(); + } + + if (!add_lhs) { + return {}; + } + + // The right hand side of the add expression is used for retrieving the dimensions of the matrix. + // obtain the x, y, r dimensions + // this expr looks like below, the shape of `add_lhs->a` can be seen further down below + // broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r) + broadcast(ramp(broadcast(base, r), broadcast(4, r), x) , y) + const Add *dim_expr = add_lhs->b.as(); + + if (!dim_expr) { + return {}; + } + + // broadcast(ramp(broadcast(_, r), broadcast(4, r), x), y) + const Broadcast *base_stride_bc = dim_expr->b.as(); + + if (!base_stride_bc) { + return {}; + } + + int tile_y = base_stride_bc->lanes; + + // broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r) + const Mod *mod = dim_expr->a.as(); + + if (!mod) { + return {}; + } + + // broadcast(ramp(0, 1, r), x*y) + const Broadcast *bc_ramp = mod->a.as(); + + if (!bc_ramp) { + return {}; + } + + int tile_xy = bc_ramp->lanes; + int tile_x = tile_xy / tile_y; + + // ramp(0, 1, r) + const Ramp *r_ramp = bc_ramp->value.as(); + + if (!r_ramp) { + return {}; + } + + int tile_r = r_ramp->lanes; + + // get the base and stride + // ramp(broadcast(_, r), broadcast(4, r), x) + const Ramp *base_stride_ramp = base_stride_bc->value.as(); + + if (!base_stride_ramp) { + return {}; + } + + // broadcast(_, r) + const Broadcast *base_bc = base_stride_ramp->base.as(); + + if (!base_bc) { + return {}; + } + + Expr base = base_bc->value; + Expr stride; + + bool found_stride = false; + + // the following pattern will match the following shape + // broadcast(ramp(0, 1, k), x*y) / broadcast(4, x*y*k) * broadcast(_, x*y*k) + // where the stride is marked by _. + + // this stride pattern can occur if `tile_r` is the same size as `acc` + auto stride_pattern = Broadcast::make(Ramp::make(0, 1, tile_r), tile_x * tile_y) / Broadcast::make((4 / element_width), tile_x * tile_y * tile_r) * Broadcast::make(wild_i32, tile_x * tile_y * tile_r); + + std::vector results{}; + if (expr_match(stride_pattern, add_lhs->a, results)) { + found_stride = true; + stride = std::move(results[0]); + } + + // This pattern is similar to the above except with an additional offset to iterate over the tiles in the k dimension + // (broadcast(ramp(0, 1, k), m * n) / broadcast(4, m*n*k) + _) * broadcast(_, m*n*k) + // here the first _ marks the base and the second _ the stride. + if (!found_stride) { + stride_pattern = (Broadcast::make(Ramp::make(0, 1, tile_r), tile_x * tile_y) / Broadcast::make((4 / element_width), tile_x * tile_y * tile_r) + wild_i32) * Broadcast::make(wild_i32, tile_x * tile_y * tile_r); + if (expr_match(stride_pattern, add_lhs->a, results)) { + found_stride = true; + stride = std::move(results[1]); + base = std::move(results[0]) * stride + base; + } + } + + if (!found_stride) { + return {}; + } + + return {true, base, {stride, 0, 0}, {tile_x, tile_y, tile_r}}; +} + +struct BaseStride { + bool result{false}; + Expr base{}; + Expr stride{}; +}; + +BaseStride get_rhs_tile_index(const Expr &index, int element_width, int tile_x, int tile_y, int tile_r) { + const auto rhs_tile2 = get_2d_tile_index(index); + + if (!rhs_tile2.result) { + const auto rhs_tile1 = get_1d_tile_index(index); + + if (!rhs_tile1.result) { + auto rhs_tile3 = get_3d_rhs_tile_index(index, element_width); + if (rhs_tile3.extent[0] != tile_x || rhs_tile3.extent[1] != tile_y || rhs_tile3.extent[2] != tile_r) { + return {}; + } + + return {true, rhs_tile3.base, rhs_tile3.stride[0] * element_width}; + } else { + if (rhs_tile1.extent[0] != tile_y * tile_r) { + return {}; + } + + // times 4 because of the rhs layout, each vector used by AMX is 4 bytes in size. + // For the 4 gets divided by the element width which means each vector has 4 elements in u8/i8 and + // 2 elements for bf16. + return {true, rhs_tile1.base, rhs_tile1.stride[0] * (4 / element_width)}; + } + } else { + if (tile_y != rhs_tile2.extent[0] || tile_r != rhs_tile2.extent[1]) { + return {}; + } + + return {true, rhs_tile2.base, rhs_tile2.stride[0]}; + } +} + struct Matmul { bool result = false; Stmt stmt; @@ -197,17 +425,41 @@ Matmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType op_t const auto *lhs_load = matches[0].as(); const auto *rhs_broadcast = matches[1].as(); - if (!lhs_load || !rhs_broadcast) { + + const Cast *rhs_cast = nullptr; + + if (lhs_load && !rhs_broadcast) { + // now working on a larger k dimension + // with a K dimension of 4 (or 2) with bf16 all the elements in the right-hand matrix are + // layed out in a way that multiplying with a column can be done in a single dot product. + // Therefore the indexing can be reused with a broadcast, + // with higher K dimensions this can no longer be done and the broadcast won't exist. + // ┌──┐ + // │1 │ + // │2 │ + // │3 │ ┌────────┐ + // │4 │ │1234 │ + // │5 │ │5678 │ + // │6 │ └────────┘ + // │7 │ + // │8 │ + // └──┘ + rhs_cast = matches[1].as(); + } else { + rhs_cast = rhs_broadcast->value.as(); + } + + if (!lhs_load || !rhs_cast) { return {}; } - const auto *rhs_cast = rhs_broadcast->value.as(); + if (rhs_cast) { - if (op_type == AMXOpType::Int8) { - if (!(rhs_cast->value.type().element_of() == Int(8) || rhs_cast->value.type().element_of() == UInt(8))) { - user_assert(false) << "Expected rhs cast of i8/u8"; - } - } else { // AMXOpType::Bfloat16 - user_assert(rhs_cast->value.type().element_of() == BFloat(16)) << "Expected rhs cast of bf16"; + bool is_i8_u8 = rhs_cast->value.type().element_of() == Int(8) || rhs_cast->value.type().element_of() == UInt(8); + bool is_bf16 = rhs_cast->value.type().element_of() == BFloat(16); + + if ((op_type == AMXOpType::Int8 && !is_i8_u8) || (op_type == AMXOpType::Bfloat16 && !is_bf16)) { + user_error << "Expected rhs type of " << (op_type == AMXOpType::Int8 ? "i8/u8" : "bf16") + << ", got " << rhs_cast->value.type() << " instead.\nIn Expression: " << Expr(rhs_cast); } } else { return {}; @@ -232,40 +484,20 @@ Matmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType op_t Expr rhs_base; Expr rhs_stride; - const auto rhs_tile2 = get_2d_tile_index(rhs_load->index); - if (!rhs_tile2.result) { - const auto rhs_tile1 = get_1d_tile_index(rhs_load->index); - - if (!rhs_tile1.result) { - return {}; - } - - if (rhs_tile1.extent[0] != tile_y * tile_r) { - return {}; - } + auto opt_base_stride = get_rhs_tile_index(rhs_load->index, amx_op_type_size(op_type), tile_x, tile_y, tile_r); - rhs_base = rhs_tile1.base; - rhs_stride = rhs_tile1.stride[0]; - } else { - if (tile_y != rhs_tile2.extent[0] || tile_r != rhs_tile2.extent[1]) { - return {}; - } - - rhs_base = rhs_tile2.base; - rhs_stride = rhs_tile2.stride[0]; + if (!opt_base_stride.result) { + return {}; } + rhs_base = opt_base_stride.base; + rhs_stride = opt_base_stride.stride; + if (op->index.type().lanes() != tile_x * tile_y || factor != tile_r) { return {}; } -#if LLVM_VERSION < 130 - user_assert(op_type != AMXOpType::Bfloat16 && - lhs_load->type.is_int() && rhs_cast->value.type().is_int()) - << "LLVM 13 or above is required for unsigned or float AMX instructions"; -#endif - // {rows, colbytes, var, index} auto lhs_var = Variable::make(Handle(), lhs_load->name); const auto &lhs_load_type = lhs_load->type; @@ -276,7 +508,8 @@ Matmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType op_t auto rhs_var = Variable::make(Handle(), rhs_load->name); const auto &rhs_load_type = rhs_load->type; auto rhs_type = rhs_load_type.with_lanes(1024 / element_width); - auto rhs = Call::make(rhs_type, "tile_load", {1, tile_y * tile_r * element_width, rhs_var, rhs_base * element_width, rhs_stride * tile_y * element_width}, Call::Intrinsic); + + auto rhs = Call::make(rhs_type, "tile_load", {tile_r / (4 / element_width), tile_y * 4, rhs_var, rhs_base * element_width, rhs_stride}, Call::Intrinsic); auto res_type = amx_op_type_result_type(op_type); // {rows, colbytes, acc, out, lhs, rhs} @@ -416,6 +649,7 @@ class ExtractTileOperations : public IRMutator { found_tile_x = matmul.tile_x; found_tile_y = matmul.tile_y; found_tile_r = matmul.tile_r; + return matmul.stmt; } @@ -430,7 +664,7 @@ class ExtractTileOperations : public IRMutator { } // Otherwise there is some other operation using the allocation, so we cannot use the AMX instructions - user_assert(false) << "Found non-tile operations for AMX tile allocation"; + user_error << "Found non-tile operations for AMX tile allocation"; return op; } }; diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp index 79a1f8d61bd5..0a33b85822aa 100644 --- a/src/FindIntrinsics.cpp +++ b/src/FindIntrinsics.cpp @@ -410,6 +410,14 @@ class FindIntrinsics : public IRMutator { saturating_sub(x, y), op->type.is_uint() && is_x_same_uint) || + // Saturating narrow patterns. + rewrite(max(min(x, upper), lower), + saturating_cast(op->type, x)) || + + rewrite(min(x, upper), + saturating_cast(op->type, x), + is_uint(x)) || + // Averaging patterns // // We have a slight preference for rounding_halving_add over @@ -447,18 +455,10 @@ class FindIntrinsics : public IRMutator { rounding_halving_add(x, y), is_x_same_int_or_uint) || - rewrite(halving_add(widening_sub(x, y), 1), - rounding_halving_sub(x, y), - is_x_same_int_or_uint) || - rewrite(rounding_shift_right(widening_add(x, y), 1), rounding_halving_add(x, y), is_x_same_int_or_uint) || - rewrite(rounding_shift_right(widening_sub(x, y), 1), - rounding_halving_sub(x, y), - is_x_same_int_or_uint) || - // Multiply-keep-high-bits patterns. rewrite(max(min(shift_right(widening_mul(x, y), z), upper), lower), mul_shift_right(x, y, cast(unsigned_type, z)), @@ -521,10 +521,6 @@ class FindIntrinsics : public IRMutator { halving_sub(x, y), is_x_same_int_or_uint) || - rewrite(rounding_shift_right(cast(op_type_wide, widening_sub(x, y)), 1), - rounding_halving_sub(x, y), - is_x_same_int_or_uint) || - false) { internal_assert(rewrite.result.type() == op->type) << "Rewrite changed type: " << Expr(op) << " -> " << rewrite.result << "\n"; @@ -587,18 +583,66 @@ class FindIntrinsics : public IRMutator { return rewrite.result; } + const int bits = op->type.bits(); + const auto is_x_same_int = op->type.is_int() && is_int(x, bits); + const auto is_x_same_uint = op->type.is_uint() && is_uint(x, bits); + const auto is_x_same_int_or_uint = is_x_same_int || is_x_same_uint; + auto x_y_same_sign = (is_int(x) == is_int(y)) || (is_uint(x) && is_uint(y)); + Type unsigned_type = op->type.with_code(halide_type_uint); + const auto is_x_wider_int_or_uint = (op->type.is_int() && is_int(x, 2 * bits)) || (op->type.is_uint() && is_uint(x, 2 * bits)); + Type opposite_type = op->type.is_int() ? op->type.with_code(halide_type_uint) : op->type.with_code(halide_type_int); + const auto is_x_wider_opposite_int = (op->type.is_int() && is_uint(x, 2 * bits)) || (op->type.is_uint() && is_int(x, 2 * bits)); + + if ( + // Saturating patterns. + rewrite(saturating_cast(op->type, widening_add(x, y)), + saturating_add(x, y), + is_x_same_int_or_uint) || + rewrite(saturating_cast(op->type, widening_sub(x, y)), + saturating_sub(x, y), + is_x_same_int_or_uint) || + rewrite(saturating_cast(op->type, shift_right(widening_mul(x, y), z)), + mul_shift_right(x, y, cast(unsigned_type, z)), + is_x_same_int_or_uint && x_y_same_sign && is_uint(z)) || + rewrite(saturating_cast(op->type, rounding_shift_right(widening_mul(x, y), z)), + rounding_mul_shift_right(x, y, cast(unsigned_type, z)), + is_x_same_int_or_uint && x_y_same_sign && is_uint(z)) || + // We can remove unnecessary widening if we are then performing a saturating narrow. + // This is similar to the logic inside `visit_min_or_max`. + (((bits <= 32) && + // Examples: + // i8_sat(int16(i8)) -> i8 + // u8_sat(uint16(u8)) -> u8 + rewrite(saturating_cast(op->type, cast(op->type.widen(), x)), + x, + is_x_same_int_or_uint)) || + ((bits <= 16) && + // Examples: + // i8_sat(int32(i16)) -> i8_sat(i16) + // u8_sat(uint32(u16)) -> u8_sat(u16) + (rewrite(saturating_cast(op->type, cast(op->type.widen().widen(), x)), + saturating_cast(op->type, x), + is_x_wider_int_or_uint) || + // Examples: + // i8_sat(uint32(u16)) -> i8_sat(u16) + // u8_sat(int32(i16)) -> i8_sat(i16) + rewrite(saturating_cast(op->type, cast(opposite_type.widen().widen(), x)), + saturating_cast(op->type, x), + is_x_wider_opposite_int) || + false))) || + false) { + return mutate(rewrite.result); + } + if (no_overflow(op->type)) { // clang-format off if (rewrite(halving_add(x + y, 1), rounding_halving_add(x, y)) || rewrite(halving_add(x, y + 1), rounding_halving_add(x, y)) || rewrite(halving_add(x + 1, y), rounding_halving_add(x, y)) || - rewrite(halving_add(x - y, 1), rounding_halving_sub(x, y)) || - rewrite(halving_sub(x + 1, y), rounding_halving_sub(x, y)) || rewrite(halving_add(x, 1), rounding_shift_right(x, 1)) || rewrite(shift_right(x + y, 1), halving_add(x, y)) || rewrite(shift_right(x - y, 1), halving_sub(x, y)) || rewrite(rounding_shift_right(x + y, 1), rounding_halving_add(x, y)) || - rewrite(rounding_shift_right(x - y, 1), rounding_halving_sub(x, y)) || false) { return mutate(rewrite.result); } @@ -901,6 +945,47 @@ Expr lower_saturating_sub(const Expr &a, const Expr &b) { return simplify(clamp(a, a.type().min() + max(b, 0), a.type().max() + min(b, 0))) - b; } +Expr lower_saturating_cast(const Type &t, const Expr &a) { + // For float to float, guarantee infinities are always pinned to range. + if (t.is_float() && a.type().is_float()) { + if (t.bits() < a.type().bits()) { + return cast(t, clamp(a, t.min(), t.max())); + } else { + return clamp(cast(t, a), t.min(), t.max()); + } + } else if (a.type() != t) { + // Limits for Int(2^n) or UInt(2^n) are not exactly representable in Float(2^n) + if (a.type().is_float() && !t.is_float() && t.bits() >= a.type().bits()) { + Expr e = max(a, t.min()); // min values turn out to be always representable + + // This line depends on t.max() rounding upward, which should always + // be the case as it is one less than a representable value, thus + // the one larger is always the closest. + e = select(e >= cast(e.type(), t.max()), t.max(), cast(t, e)); + return e; + } else { + Expr min_bound; + if (!a.type().is_uint()) { + min_bound = lossless_cast(a.type(), t.min()); + } + Expr max_bound = lossless_cast(a.type(), t.max()); + + Expr e; + if (min_bound.defined() && max_bound.defined()) { + e = clamp(a, min_bound, max_bound); + } else if (min_bound.defined()) { + e = max(a, min_bound); + } else if (max_bound.defined()) { + e = min(a, max_bound); + } else { + e = a; + } + return cast(t, std::move(e)); + } + } + return a; +} + Expr lower_halving_add(const Expr &a, const Expr &b) { internal_assert(a.type() == b.type()); // Borrowed from http://aggregate.org/MAGIC/#Average%20of%20Integers @@ -909,19 +994,30 @@ Expr lower_halving_add(const Expr &a, const Expr &b) { Expr lower_halving_sub(const Expr &a, const Expr &b) { internal_assert(a.type() == b.type()); - return (a >> 1) - (b >> 1) - (((b & 1) - (a & 1) + 1) >> 1); + Expr e = rounding_halving_add(a, ~b); + if (a.type().is_uint()) { + // An explanation in 8-bit: + // (x - y) / 2 + // = (x + 256 - y) / 2 - 128 + // = (x + (255 - y) + 1) / 2 - 128 + // = (x + ~y + 1) / 2 - 128 + // = rounding_halving_add(x, ~y) - 128 + // = rounding_halving_add(x, ~y) + 128 (due to 2s-complement wrap-around) + return e + make_const(e.type(), (uint64_t)1 << (a.type().bits() - 1)); + } else { + // For 2s-complement signed integers, negating is done by flipping the + // bits and adding one, so: + // (x - y) / 2 + // = (x + (-y)) / 2 + // = (x + (~y + 1)) / 2 + // = rounding_halving_add(x, ~y) + return e; + } } -// TODO: These should using rounding_shift_right, but lowering that -// results in double widening and the simplifier doesn't fix it. Expr lower_rounding_halving_add(const Expr &a, const Expr &b) { internal_assert(a.type() == b.type()); - return (a >> 1) + (b >> 1) + (((a & 1) + (b & 1) + 1) >> 1); -} - -Expr lower_rounding_halving_sub(const Expr &a, const Expr &b) { - internal_assert(a.type() == b.type()); - return (a >> 1) - (b >> 1) + (((a & 1) - (b & 1) + 1) >> 1); + return halving_add(a, b) + ((a ^ b) & 1); } Expr lower_sorted_avg(const Expr &a, const Expr &b) { @@ -1019,6 +1115,9 @@ Expr lower_intrinsic(const Call *op) { } else if (op->is_intrinsic(Call::saturating_sub)) { internal_assert(op->args.size() == 2); return lower_saturating_sub(op->args[0], op->args[1]); + } else if (op->is_intrinsic(Call::saturating_cast)) { + internal_assert(op->args.size() == 1); + return lower_saturating_cast(op->type, op->args[0]); } else if (op->is_intrinsic(Call::widening_shift_left)) { internal_assert(op->args.size() == 2); return lower_widening_shift_left(op->args[0], op->args[1]); @@ -1040,9 +1139,6 @@ Expr lower_intrinsic(const Call *op) { } else if (op->is_intrinsic(Call::rounding_halving_add)) { internal_assert(op->args.size() == 2); return lower_rounding_halving_add(op->args[0], op->args[1]); - } else if (op->is_intrinsic(Call::rounding_halving_sub)) { - internal_assert(op->args.size() == 2); - return lower_rounding_halving_sub(op->args[0], op->args[1]); } else if (op->is_intrinsic(Call::rounding_mul_shift_right)) { internal_assert(op->args.size() == 3); return lower_rounding_mul_shift_right(op->args[0], op->args[1], op->args[2]); diff --git a/src/FindIntrinsics.h b/src/FindIntrinsics.h index b42599d2330e..07e639117252 100644 --- a/src/FindIntrinsics.h +++ b/src/FindIntrinsics.h @@ -22,11 +22,11 @@ Expr lower_rounding_shift_right(const Expr &a, const Expr &b); Expr lower_saturating_add(const Expr &a, const Expr &b); Expr lower_saturating_sub(const Expr &a, const Expr &b); +Expr lower_saturating_cast(const Type &t, const Expr &a); Expr lower_halving_add(const Expr &a, const Expr &b); Expr lower_halving_sub(const Expr &a, const Expr &b); Expr lower_rounding_halving_add(const Expr &a, const Expr &b); -Expr lower_rounding_halving_sub(const Expr &a, const Expr &b); Expr lower_mul_shift_right(const Expr &a, const Expr &b, const Expr &q); Expr lower_rounding_mul_shift_right(const Expr &a, const Expr &b, const Expr &q); diff --git a/src/FlattenNestedRamps.cpp b/src/FlattenNestedRamps.cpp index 809a5e445053..0a57af2b58db 100644 --- a/src/FlattenNestedRamps.cpp +++ b/src/FlattenNestedRamps.cpp @@ -96,6 +96,7 @@ class FlattenRamps : public IRMutator { // Compute the number of elements loaded int extent = (int)((max_constant_offset / stride) + 1); + // If we're gathering from a very large range, it // might be better to just do the gather rather than // doing a big dense load and then shuffling. We @@ -120,16 +121,54 @@ class FlattenRamps : public IRMutator { } }; +/** Simplify bit concatenation of interleaved loads to vector reinterprets of + * dense loads. Must be done to both vectors and scalars after flattening nested + * ramps, because it can expand a flat ramp into a wider one. */ +class SimplifyConcatBits : public IRMutator { + using IRMutator::visit; + + Expr visit(const Call *op) override { + if (op->is_intrinsic(Call::concat_bits)) { + // Simplify a concat of a load of adjacent bits to a reinterpret of a load of a small vector. + const Load *l0 = op->args[0].as(); + bool ok = true; + const int n = (int)(op->args.size()); + for (int i = 0; ok && i < n; i++) { + const Load *li = op->args[i].as(); + ok &= (li != nullptr); + if (!ok) { + break; + } + const Ramp *r = li->index.as(); + Expr base = r ? r->base : li->index; + ok &= (is_const_one(li->predicate) && + l0->name == li->name && + can_prove(l0->index + i == li->index) && + (r == nullptr || is_const(r->stride, n))); + } + + if (ok) { + internal_assert(l0); + const Ramp *r0 = l0->index.as(); + int new_lanes = (r0 ? r0->lanes : 1) * n; + Expr base = r0 ? r0->base : l0->index; + Expr idx = Ramp::make(base, 1, new_lanes); + return mutate(Reinterpret::make(op->type, Load::make(l0->type.with_lanes(n * l0->type.lanes()), l0->name, idx, l0->image, l0->param, const_true(new_lanes), l0->alignment))); + } + } + + return IRMutator::visit(op); + } +}; + } // namespace Stmt flatten_nested_ramps(const Stmt &s) { - FlattenRamps flatten_ramps; - return flatten_ramps.mutate(s); + return SimplifyConcatBits().mutate(FlattenRamps().mutate(s)); } Expr flatten_nested_ramps(const Expr &e) { - FlattenRamps flatten_ramps; - return flatten_ramps.mutate(e); + return SimplifyConcatBits().mutate(FlattenRamps().mutate(e)); } } // namespace Internal diff --git a/src/Float16.h b/src/Float16.h index 9373662982fb..93727795cf4d 100644 --- a/src/Float16.h +++ b/src/Float16.h @@ -127,7 +127,7 @@ static_assert(sizeof(float16_t) == 2, "float16_t should occupy two bytes"); } // namespace Halide template<> -HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { +HALIDE_ALWAYS_INLINE constexpr halide_type_t halide_type_of() { return halide_type_t(halide_type_float, 16); } @@ -254,7 +254,7 @@ static_assert(sizeof(bfloat16_t) == 2, "bfloat16_t should occupy two bytes"); } // namespace Halide template<> -HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { +HALIDE_ALWAYS_INLINE constexpr halide_type_t halide_type_of() { return halide_type_t(halide_type_bfloat, 16); } diff --git a/src/Func.cpp b/src/Func.cpp index ebdcb44a8c84..2051db5668da 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -10,6 +10,7 @@ #include "ApplySplit.h" #include "Argument.h" #include "Associativity.h" +#include "Callable.h" #include "CodeGen_LLVM.h" #include "Debug.h" #include "ExprUsesVar.h" @@ -59,6 +60,14 @@ Func::Func(const string &name) : func(unique_name(name)) { } +Func::Func(const Type &required_type, int required_dims, const string &name) + : func({required_type}, required_dims, unique_name(name)) { +} + +Func::Func(const std::vector &required_types, int required_dims, const string &name) + : func(required_types, required_dims, unique_name(name)) { +} + Func::Func() : func(make_entity_name(this, "Halide:.*:Func", 'f')) { } @@ -188,13 +197,33 @@ void Func::define_extern(const std::string &function_name, } /** Get the types of the buffers returned by an extern definition. */ -const std::vector &Func::output_types() const { - return func.output_types(); +const Type &Func::type() const { + const auto &types = defined() ? func.output_types() : func.required_types(); + if (types.empty()) { + user_error << "Can't call Func::type on Func \"" << name() + << "\" because it is undefined or has no type requirements.\n"; + } else if (types.size() > 1) { + user_error << "Can't call Func::type on Func \"" << name() + << "\" because it returns a Tuple.\n"; + } + return types[0]; +} + +const std::vector &Func::types() const { + const auto &types = defined() ? func.output_types() : func.required_types(); + user_assert(!types.empty()) + << "Can't call Func::types on Func \"" << name() + << "\" because it is undefined or has no type requirements.\n"; + return types; } /** Get the number of outputs this function has. */ int Func::outputs() const { - return func.outputs(); + const auto &types = defined() ? func.output_types() : func.required_types(); + user_assert(!types.empty()) + << "Can't call Func::outputs on Func \"" << name() + << "\" because it is undefined or has no type requirements.\n"; + return (int)types.size(); } /** Get the name of the extern function called for an extern @@ -204,10 +233,11 @@ const std::string &Func::extern_function_name() const { } int Func::dimensions() const { - if (!defined()) { - return 0; - } - return func.dimensions(); + const int dims = defined() ? func.dimensions() : func.required_dimensions(); + user_assert(dims != AnyDims) + << "Can't call Func::dimensions on Func \"" << name() + << "\" because it is undefined or has no dimension requirements.\n"; + return dims; } FuncRef Func::operator()(vector args) const { @@ -232,7 +262,9 @@ std::pair Func::add_implicit_vars(vector &args) const { placeholder_pos = (int)(iter - args.begin()); int i = 0; iter = args.erase(iter); - while ((int)args.size() < dimensions()) { + // It's important to use func.dimensions() here, *not* this->dimensions(), + // since the latter can return the Func's required dimensions rather than its actual dimensions. + while ((int)args.size() < func.dimensions()) { Internal::debug(2) << "Adding implicit var " << i << " to call to " << name() << "\n"; iter = args.insert(iter, Var::implicit(i++)); iter++; @@ -240,9 +272,9 @@ std::pair Func::add_implicit_vars(vector &args) const { } } - if (defined() && args.size() != (size_t)dimensions()) { + if (defined() && args.size() != (size_t)func.dimensions()) { user_error << "Func \"" << name() << "\" was called with " - << args.size() << " arguments, but was defined with " << dimensions() << "\n"; + << args.size() << " arguments, but was defined with " << func.dimensions() << "\n"; } return {placeholder_pos, count}; @@ -263,7 +295,9 @@ std::pair Func::add_implicit_vars(vector &args) const { placeholder_pos = (int)(iter - args.begin()); int i = 0; iter = args.erase(iter); - while ((int)args.size() < dimensions()) { + // It's important to use func.dimensions() here, *not* this->dimensions(), + // since the latter can return the Func's required dimensions rather than its actual dimensions. + while ((int)args.size() < func.dimensions()) { Internal::debug(2) << "Adding implicit var " << i << " to call to " << name() << "\n"; iter = args.insert(iter, Var::implicit(i++)); iter++; @@ -271,9 +305,9 @@ std::pair Func::add_implicit_vars(vector &args) const { } } - if (defined() && args.size() != (size_t)dimensions()) { + if (defined() && args.size() != (size_t)func.dimensions()) { user_error << "Func \"" << name() << "\" was called with " - << args.size() << " arguments, but was defined with " << dimensions() << "\n"; + << args.size() << " arguments, but was defined with " << func.dimensions() << "\n"; } return {placeholder_pos, count}; @@ -948,7 +982,7 @@ Func Stage::rfactor(vector> preserved) { } if (!prover_result.xs[i].var.empty()) { - Expr prev_val = Call::make(intm.output_types()[i], func_name, + Expr prev_val = Call::make(intm.types()[i], func_name, f_store_args, Call::CallType::Halide, FunctionPtr(), i); replacements.emplace(prover_result.xs[i].var, prev_val); @@ -2915,9 +2949,11 @@ Stage FuncRef::operator=(const FuncRef &e) { } } +namespace { + // Inject a suitable base-case definition given an update // definition. This is a helper for FuncRef::operator+= and co. -Func define_base_case(const Internal::Function &func, const vector &a, const Tuple &e) { +Func define_base_case(const Internal::Function &func, const vector &a, const vector &rhs, int init_val) { Func f(func); if (func.has_pure_definition()) { @@ -2936,24 +2972,32 @@ Func define_base_case(const Internal::Function &func, const vector &a, con } } - f(pure_args) = e; + const auto &required_types = func.required_types(); + internal_assert(required_types.empty() || required_types.size() == rhs.size()); + + vector init_values(rhs.size()); + for (size_t i = 0; i < rhs.size(); ++i) { + // If we have required types, cast the init_val to that type instead of the rhs type + const Type &t = required_types.empty() ? rhs[i].type() : required_types[i]; + init_values[i] = cast(t, init_val); + } + + f(pure_args) = Tuple(init_values); return f; } -Func define_base_case(const Internal::Function &func, const vector &a, const Expr &e) { - return define_base_case(func, a, Tuple(e)); -} +} // namespace template Stage FuncRef::func_ref_update(const Tuple &e, int init_val) { + // Don't do this: we want to allow the RHS to be implicitly cast to the type of LHS. + // func.check_types(e); + internal_assert(e.size() > 1); - vector init_values(e.size()); - for (int i = 0; i < (int)init_values.size(); ++i) { - init_values[i] = cast(e[i].type(), init_val); - } - vector expanded_args = args_with_implicit_vars(e.as_vector()); - FuncRef self_ref = define_base_case(func, expanded_args, Tuple(init_values))(expanded_args); + const vector &rhs = e.as_vector(); + const vector expanded_args = args_with_implicit_vars(rhs); + FuncRef self_ref = define_base_case(func, expanded_args, rhs, init_val)(expanded_args); vector values(e.size()); for (int i = 0; i < (int)values.size(); ++i) { @@ -2964,8 +3008,12 @@ Stage FuncRef::func_ref_update(const Tuple &e, int init_val) { template Stage FuncRef::func_ref_update(Expr e, int init_val) { - vector expanded_args = args_with_implicit_vars({e}); - FuncRef self_ref = define_base_case(func, expanded_args, cast(e.type(), init_val))(expanded_args); + // Don't do this: we want to allow the RHS to be implicitly cast to the type of LHS. + // func.check_types(e); + + const vector rhs = {e}; + const vector expanded_args = args_with_implicit_vars(rhs); + FuncRef self_ref = define_base_case(func, expanded_args, rhs, init_val)(expanded_args); return self_ref = BinaryOp()(Expr(self_ref), e); } @@ -3157,26 +3205,25 @@ void Func::infer_input_bounds(JITUserContext *context, Buffer<> im(func.output_types()[i], nullptr, sizes); outputs[i] = std::move(im); } - Realization r(outputs); + Realization r(std::move(outputs)); infer_input_bounds(context, r, target, param_map); } OutputImageParam Func::output_buffer() const { - user_assert(defined()) - << "Can't access output buffer of undefined Func.\n"; - user_assert(func.output_buffers().size() == 1) + const auto &ob = func.output_buffers(); + + user_assert(ob.size() == 1) << "Can't call Func::output_buffer on Func \"" << name() << "\" because it returns a Tuple.\n"; - return OutputImageParam(func.output_buffers()[0], Argument::OutputBuffer, *this); + return OutputImageParam(ob[0], Argument::OutputBuffer, *this); } vector Func::output_buffers() const { - user_assert(defined()) - << "Can't access output buffers of undefined Func.\n"; + const auto &ob = func.output_buffers(); - vector bufs(func.output_buffers().size()); + vector bufs(ob.size()); for (size_t i = 0; i < bufs.size(); i++) { - bufs[i] = OutputImageParam(func.output_buffers()[i], Argument::OutputBuffer, *this); + bufs[i] = OutputImageParam(ob[i], Argument::OutputBuffer, *this); } return bufs; } @@ -3309,33 +3356,6 @@ void set_handler(A &a, B b) { } } // namespace -// Deprecated setters for JIT handlers -void Func::set_error_handler(void (*handler)(void *, const char *)) { - set_handler(jit_handlers().custom_error, handler); -} - -void Func::set_custom_allocator(void *(*cust_malloc)(void *, size_t), - void (*cust_free)(void *, void *)) { - set_handler(jit_handlers().custom_malloc, cust_malloc); - set_handler(jit_handlers().custom_free, cust_free); -} - -void Func::set_custom_do_par_for(int (*cust_do_par_for)(void *, int (*)(void *, int, uint8_t *), int, int, uint8_t *)) { - set_handler(jit_handlers().custom_do_par_for, cust_do_par_for); -} - -void Func::set_custom_do_task(int (*cust_do_task)(void *, int (*)(void *, int, uint8_t *), int, uint8_t *)) { - set_handler(jit_handlers().custom_do_task, cust_do_task); -} - -void Func::set_custom_trace(int (*trace_fn)(void *, const halide_trace_event_t *)) { - set_handler(jit_handlers().custom_trace, trace_fn); -} - -void Func::set_custom_print(void (*cust_print)(void *, const char *)) { - set_handler(jit_handlers().custom_print, cust_print); -} - void Func::add_custom_lowering_pass(IRMutator *pass, std::function deleter) { pipeline().add_custom_lowering_pass(pass, std::move(deleter)); } @@ -3382,4 +3402,8 @@ void Func::compile_jit(const Target &target) { pipeline().compile_jit(target); } +Callable Func::compile_to_callable(const std::vector &args, const Target &target) { + return pipeline().compile_to_callable(args, target); +} + } // namespace Halide diff --git a/src/Func.h b/src/Func.h index 76ee46e459d3..8e9fff31915b 100644 --- a/src/Func.h +++ b/src/Func.h @@ -446,22 +446,6 @@ class Stage { Stage &hexagon(const VarOrRVar &x = Var::outermost()); - HALIDE_ATTRIBUTE_DEPRECATED("Call prefetch() with the two-var form instead.") - Stage &prefetch(const Func &f, const VarOrRVar &var, int offset = 1, - PrefetchBoundStrategy strategy = PrefetchBoundStrategy::GuardWithIf) { - return prefetch(f, var, var, offset, strategy); - } - HALIDE_ATTRIBUTE_DEPRECATED("Call prefetch() with the two-var form instead.") - Stage &prefetch(const Internal::Parameter ¶m, const VarOrRVar &var, int offset = 1, - PrefetchBoundStrategy strategy = PrefetchBoundStrategy::GuardWithIf) { - return prefetch(param, var, var, offset, strategy); - } - template - HALIDE_ATTRIBUTE_DEPRECATED("Call prefetch() with the two-var form instead.") - Stage &prefetch(const T &image, VarOrRVar var, int offset = 1, - PrefetchBoundStrategy strategy = PrefetchBoundStrategy::GuardWithIf) { - return prefetch(image.parameter(), var, var, offset, strategy); - } Stage &prefetch(const Func &f, const VarOrRVar &at, const VarOrRVar &from, Expr offset = 1, PrefetchBoundStrategy strategy = PrefetchBoundStrategy::GuardWithIf); Stage &prefetch(const Internal::Parameter ¶m, const VarOrRVar &at, const VarOrRVar &from, Expr offset = 1, @@ -736,6 +720,19 @@ class Func { /** Declare a new undefined function with the given name */ explicit Func(const std::string &name); + /** Declare a new undefined function with the given name. + * The function will be constrained to represent Exprs of required_type. + * If required_dims is not AnyDims, the function will be constrained to exactly + * that many dimensions. */ + explicit Func(const Type &required_type, int required_dims, const std::string &name); + + /** Declare a new undefined function with the given name. + * If required_types is not empty, the function will be constrained to represent + * Tuples of the same arity and types. (If required_types is empty, there is no constraint.) + * If required_dims is not AnyDims, the function will be constrained to exactly + * that many dimensions. */ + explicit Func(const std::vector &required_types, int required_dims, const std::string &name); + /** Declare a new undefined function with an * automatically-generated unique name */ Func(); @@ -1058,34 +1055,19 @@ class Func { */ void compile_jit(const Target &target = get_jit_target_from_environment()); - /** Deprecated variants of the above that use a void pointer - * instead of a JITUserContext pointer. */ - // @{ - HALIDE_ATTRIBUTE_DEPRECATED("Custom handlers should by set by modifying the struct returned by jit_handlers()") - void set_error_handler(void (*handler)(void *, const char *)); - HALIDE_ATTRIBUTE_DEPRECATED("Custom handlers should by set by modifying the struct returned by jit_handlers()") - void set_custom_allocator(void *(*malloc)(void *, size_t), - void (*free)(void *, void *)); - HALIDE_ATTRIBUTE_DEPRECATED("Custom handlers should by set by modifying the struct returned by jit_handlers()") - void set_custom_do_task( - int (*custom_do_task)(void *, int (*)(void *, int, uint8_t *), - int, uint8_t *)); - HALIDE_ATTRIBUTE_DEPRECATED("Custom handlers should by set by modifying the struct returned by jit_handlers()") - void set_custom_do_par_for( - int (*custom_do_par_for)(void *, int (*)(void *, int, uint8_t *), int, - int, uint8_t *)); - HALIDE_ATTRIBUTE_DEPRECATED("Custom handlers should by set by modifying the struct returned by jit_handlers()") - void set_custom_trace(int (*trace_fn)(void *, const halide_trace_event_t *)); - - HALIDE_ATTRIBUTE_DEPRECATED("Custom handlers should by set by modifying the struct returned by jit_handlers()") - void set_custom_print(void (*handler)(void *, const char *)); - // @} - /** Get a struct containing the currently set custom functions * used by JIT. This can be mutated. Changes will take effect the * next time this Func is realized. */ JITHandlers &jit_handlers(); + /** Eagerly jit compile the function to machine code and return a callable + * struct that behaves like a function pointer. The calling convention + * will exactly match that of an AOT-compiled version of this Func + * with the same Argument list. + */ + Callable compile_to_callable(const std::vector &args, + const Target &target = get_jit_target_from_environment()); + /** Add a custom pass to be used during lowering. It is run after * all other lowering passes. Can be used to verify properties of * the lowered Stmt, instrument it with extra code, or otherwise @@ -1234,19 +1216,42 @@ class Func { DeviceAPI device_api = DeviceAPI::Host); // @} - /** Get the types of the outputs of this Func. */ - const std::vector &output_types() const; + /** Get the type(s) of the outputs of this Func. + * + * It is not legal to call type() unless the Func has non-Tuple elements. + * + * If the Func isn't yet defined, and was not specified with required types, + * a runtime error will occur. + * + * If the Func isn't yet defined, but *was* specified with required types, + * the requirements will be returned. */ + // @{ + const Type &type() const; + const std::vector &types() const; + // @} + + HALIDE_ATTRIBUTE_DEPRECATED("Func::output_type() is deprecated; use Func::type() instead.") + const Type &output_type() const { + return type(); + } + HALIDE_ATTRIBUTE_DEPRECATED("Func::output_types() is deprecated; use Func::types() instead.") + const std::vector &output_types() const { + return types(); + } /** Get the number of outputs of this Func. Corresponds to the - * size of the Tuple this Func was defined to return. */ + * size of the Tuple this Func was defined to return. + * If the Func isn't yet defined, but was specified with required types, + * the number of outputs specified in the requirements will be returned. */ int outputs() const; /** Get the name of the extern function called for an extern * definition. */ const std::string &extern_function_name() const; - /** The dimensionality (number of arguments) of this - * function. Zero if the function is not yet defined. */ + /** The dimensionality (number of arguments) of this function. + * If the Func isn't yet defined, but was specified with required dimensionality, + * the dimensionality specified in the requirements will be returned. */ int dimensions() const; /** Construct either the left-hand-side of a definition, or a call @@ -1441,7 +1446,7 @@ class Func { * factor does not provably divide the extent. */ Func &split(const VarOrRVar &old, const VarOrRVar &outer, const VarOrRVar &inner, const Expr &factor, TailStrategy tail = TailStrategy::Auto); - /** Join two dimensions into a single fused dimenion. The fused + /** Join two dimensions into a single fused dimension. The fused * dimension covers the product of the extents of the inner and * outer dimensions given. */ Func &fuse(const VarOrRVar &inner, const VarOrRVar &outer, const VarOrRVar &fused); @@ -1946,55 +1951,7 @@ class Func { Func &hexagon(const VarOrRVar &x = Var::outermost()); /** Prefetch data written to or read from a Func or an ImageParam by a - * subsequent loop iteration, at an optionally specified iteration offset. - * 'var' specifies at which loop level the prefetch calls should be inserted. - * The final argument specifies how prefetch of region outside bounds - * should be handled. - * - * For example, consider this pipeline: - \code - Func f, g; - Var x, y; - f(x, y) = x + y; - g(x, y) = 2 * f(x, y); - \endcode - * - * The following schedule: - \code - f.compute_root(); - g.prefetch(f, x, 2, PrefetchBoundStrategy::NonFaulting); - \endcode - * - * will inject prefetch call at the innermost loop of 'g' and generate - * the following loop nest: - * for y = ... - * for x = ... - * f(x, y) = x + y - * for y = .. - * for x = ... - * prefetch(&f[x + 2, y], 1, 16); - * g(x, y) = 2 * f(x, y) - */ - // @{ - HALIDE_ATTRIBUTE_DEPRECATED("Call prefetch() with the two-var form instead.") - Func &prefetch(const Func &f, const VarOrRVar &var, int offset = 1, - PrefetchBoundStrategy strategy = PrefetchBoundStrategy::GuardWithIf) { - return prefetch(f, var, var, offset, strategy); - } - HALIDE_ATTRIBUTE_DEPRECATED("Call prefetch() with the two-var form instead.") - Func &prefetch(const Internal::Parameter ¶m, const VarOrRVar &var, int offset = 1, - PrefetchBoundStrategy strategy = PrefetchBoundStrategy::GuardWithIf) { - return prefetch(param, var, var, offset, strategy); - } - template - HALIDE_ATTRIBUTE_DEPRECATED("Call prefetch() with the two-var form instead.") - Func &prefetch(const T &image, VarOrRVar var, int offset = 1, - PrefetchBoundStrategy strategy = PrefetchBoundStrategy::GuardWithIf) { - return prefetch(image, var, var, offset, strategy); - } - // @} - - /** prefetch() is a more fine-grained version of prefetch(), which allows + * subsequent loop iteration, at an optionally specified iteration offset. You may specify * specification of different vars for the location of the prefetch() instruction * vs. the location that is being prefetched: * @@ -2005,6 +1962,9 @@ class Func { * If 'at' and 'from' are distinct vars, then 'from' must be at a nesting level outside 'at.' * Note that the value for 'offset' applies only to 'from', not 'at'. * + * The final argument specifies how prefetch of region outside bounds + * should be handled. + * * For example, consider this pipeline: \code Func f, g; diff --git a/src/Function.cpp b/src/Function.cpp index f7eded59e824..c644d53c2784 100644 --- a/src/Function.cpp +++ b/src/Function.cpp @@ -30,6 +30,7 @@ typedef map DeepCopyMap; struct FunctionContents; namespace { + // Weaken all the references to a particular Function to break // reference cycles. Also count the number of references found. class WeakenFunctionPtrs : public IRMutator { @@ -58,6 +59,7 @@ class WeakenFunctionPtrs : public IRMutator { : func(f) { } }; + } // namespace struct FunctionContents { @@ -65,6 +67,22 @@ struct FunctionContents { std::string origin_name; std::vector output_types; + /** Optional type constraints on the Function: + * - If empty, there are no constraints. + * - If size == 1, the Func is only allowed to have values of Expr with that type + * - If size > 1, the Func is only allowed to have values of Tuple with those types + * + * Note that when this is nonempty, then output_types should match + * required_types for all defined Functions. + */ + std::vector required_types; + + /** Optional dimension constraints on the Function: + * - If required_dims == AnyDims, there are no constraints. + * - Otherwise, the Function's dimensionality must exactly match required_dims. + */ + int required_dims = AnyDims; + // The names of the dimensions of the Function. Corresponds to the // LHS of the pure definition if there is one. Is also the initial // stage of the dims and storage_dims. Used to identify dimensions @@ -306,9 +324,100 @@ Function::Function(const std::string &n) { contents->origin_name = n; } +Function::Function(const std::vector &required_types, int required_dims, const std::string &n) + : Function(n) { + user_assert(required_dims >= AnyDims); + contents->required_types = required_types; + contents->required_dims = required_dims; +} + +namespace { + +template +struct PrintTypeList { + const std::vector &list_; + + explicit PrintTypeList(const std::vector &list) + : list_(list) { + } + + friend std::ostream &operator<<(std::ostream &s, const PrintTypeList &self) { + const size_t n = self.list_.size(); + if (n != 1) { + s << "("; + } + const char *comma = ""; + for (const auto &t : self.list_) { + if constexpr (std::is_same::value) { + s << comma << t; + } else { + s << comma << t.type(); + } + comma = ", "; + } + if (n != 1) { + s << ")"; + } + return s; + } +}; + +bool types_match(const std::vector &types, const std::vector &exprs) { + size_t n = types.size(); + if (n != exprs.size()) { + return false; + } + for (size_t i = 0; i < n; i++) { + if (types[i] != exprs[i].type()) { + return false; + } + } + return true; +} + +} // namespace + +void Function::check_types(const Expr &e) const { + check_types(std::vector{e}); +} + +void Function::check_types(const Tuple &t) const { + check_types(t.as_vector()); +} + +void Function::check_types(const Type &t) const { + check_types(std::vector{t}); +} + +void Function::check_types(const std::vector &exprs) const { + if (!contents->required_types.empty()) { + user_assert(types_match(contents->required_types, exprs)) + << "Func \"" << name() << "\" is constrained to only hold values of type " << PrintTypeList(contents->required_types) + << " but is defined with values of type " << PrintTypeList(exprs) << ".\n"; + } +} + +void Function::check_types(const std::vector &types) const { + if (!contents->required_types.empty()) { + user_assert(contents->required_types == types) + << "Func \"" << name() << "\" is constrained to only hold values of type " << PrintTypeList(contents->required_types) + << " but is defined with values of type " << PrintTypeList(types) << ".\n"; + } +} + +void Function::check_dims(int dims) const { + if (contents->required_dims != AnyDims) { + user_assert(contents->required_dims == dims) + << "Func \"" << name() << "\" is constrained to have exactly " << contents->required_dims + << " dimensions, but is defined with " << dims << " dimensions.\n"; + } +} + +namespace { + // Return deep-copy of ExternFuncArgument 'src' -ExternFuncArgument deep_copy_extern_func_argument_helper( - const ExternFuncArgument &src, DeepCopyMap &copied_map) { +ExternFuncArgument deep_copy_extern_func_argument_helper(const ExternFuncArgument &src, + DeepCopyMap &copied_map) { ExternFuncArgument copy; copy.arg_type = src.arg_type; copy.buffer = src.buffer; @@ -330,6 +439,8 @@ ExternFuncArgument deep_copy_extern_func_argument_helper( return copy; } +} // namespace + void Function::deep_copy(const FunctionPtr ©, DeepCopyMap &copied_map) const { internal_assert(copy.defined() && contents.defined()) << "Cannot deep-copy undefined Function\n"; @@ -456,6 +567,8 @@ void Function::define(const vector &args, vector values) { << "In pure definition of Func \"" << name() << "\":\n" << "Func is already defined.\n"; + check_types(values); + check_dims((int)args.size()); contents->args = args; std::vector init_def_args; @@ -485,12 +598,30 @@ void Function::define(const vector &args, vector values) { contents->output_types[i] = values[i].type(); } - for (size_t i = 0; i < values.size(); i++) { + if (!contents->required_types.empty()) { + // Just a reality check; mismatches here really should have been caught earlier + internal_assert(contents->required_types == contents->output_types); + } + if (contents->required_dims != AnyDims) { + // Just a reality check; mismatches here really should have been caught earlier + internal_assert(contents->required_dims == (int)args.size()); + } + + if (contents->output_buffers.empty()) { + create_output_buffers(contents->output_types, (int)args.size()); + } +} + +void Function::create_output_buffers(const std::vector &types, int dims) const { + internal_assert(contents->output_buffers.empty()); + internal_assert(!types.empty() && dims != AnyDims); + + for (size_t i = 0; i < types.size(); i++) { string buffer_name = name(); - if (values.size() > 1) { + if (types.size() > 1) { buffer_name += '.' + std::to_string((int)i); } - Parameter output(values[i].type(), true, args.size(), buffer_name); + Parameter output(types[i], true, dims, buffer_name); contents->output_buffers.push_back(output); } } @@ -703,6 +834,8 @@ void Function::define_extern(const std::string &function_name, const std::vector &args, NameMangling mangling, DeviceAPI device_api) { + check_types(types); + check_dims((int)args.size()); user_assert(!has_pure_definition() && !has_update_definition()) << "In extern definition for Func \"" << name() << "\":\n" @@ -788,13 +921,25 @@ bool Function::is_pure_arg(const std::string &name) const { } int Function::dimensions() const { - return args().size(); + return (int)args().size(); +} + +int Function::outputs() const { + return (int)output_types().size(); } const std::vector &Function::output_types() const { return contents->output_types; } +const std::vector &Function::required_types() const { + return contents->required_types; +} + +int Function::required_dimensions() const { + return contents->required_dims; +} + const std::vector &Function::values() const { static const std::vector empty; if (has_pure_definition()) { @@ -813,6 +958,18 @@ const FuncSchedule &Function::schedule() const { } const std::vector &Function::output_buffers() const { + if (!contents->output_buffers.empty()) { + return contents->output_buffers; + } + + // If types and dims are already specified, we can go ahead and create + // the output buffer(s) even if the Function has no pure definition yet. + if (!contents->required_types.empty() && contents->required_dims != AnyDims) { + create_output_buffers(contents->required_types, contents->required_dims); + return contents->output_buffers; + } + + user_error << "Can't access output buffer(s) of undefined Func \"" << name() << "\".\n"; return contents->output_buffers; } diff --git a/src/Function.h b/src/Function.h index ce8a76ef4f17..0cbdf4688f5a 100644 --- a/src/Function.h +++ b/src/Function.h @@ -17,6 +17,7 @@ namespace Halide { struct ExternFuncArgument; +class Tuple; class Var; @@ -57,6 +58,13 @@ class Function { /** Construct a new function with the given name */ explicit Function(const std::string &n); + /** Construct a new function with the given name, + * with a requirement that it can only represent Expr(s) of the given type(s), + * and must have exactly the give nnumber of dimensions. + * required_types.empty() means there are no constraints on the type(s). + * required_dims == AnyDims means there are no constraints on the dimensions. */ + explicit Function(const std::vector &required_types, int required_dims, const std::string &n); + /** Construct a Function from an existing FunctionContents pointer. Must be non-null */ explicit Function(const FunctionPtr &); @@ -125,13 +133,17 @@ class Function { int dimensions() const; /** Get the number of outputs. */ - int outputs() const { - return (int)output_types().size(); - } + int outputs() const; /** Get the types of the outputs. */ const std::vector &output_types() const; + /** Get the type constaints on the outputs (if any). */ + const std::vector &required_types() const; + + /** Get the dimensionality constaints on the outputs (if any). */ + int required_dimensions() const; + /** Get the right-hand-side of the pure definition. Returns an * empty vector if there is no pure definition. */ const std::vector &values() const; @@ -292,6 +304,22 @@ class Function { /** Return true iff the name matches one of the Function's pure args. */ bool is_pure_arg(const std::string &name) const; + + /** If the Function has type requirements, check that the given argument + * is compatible with them. If not, assert-fail. (If there are no type requirements, do nothing.) */ + void check_types(const Expr &e) const; + void check_types(const Tuple &t) const; + void check_types(const Type &t) const; + void check_types(const std::vector &exprs) const; + void check_types(const std::vector &types) const; + + /** If the Function has dimension requirements, check that the given argument + * is compatible with them. If not, assert-fail. (If there are no dimension requirements, do nothing.) */ + void check_dims(int dims) const; + + /** Define the output buffers. If the Function has types specified, this can be called at + * any time. If not, it can only be called for a Function with a pure definition. */ + void create_output_buffers(const std::vector &types, int dims) const; }; /** Deep copy an entire Function DAG. */ diff --git a/src/FuseGPUThreadLoops.cpp b/src/FuseGPUThreadLoops.cpp index cd65a32618f5..1f1b9f05d4ec 100644 --- a/src/FuseGPUThreadLoops.cpp +++ b/src/FuseGPUThreadLoops.cpp @@ -1188,7 +1188,7 @@ class ExtractRegisterAllocations : public IRMutator { Stmt rewrap(Stmt body, const string &loop_var) { for (RegisterAllocation &alloc : allocations) { - if ((!loop_var.empty() && ends_with(alloc.loop_var, loop_var)) | + if ((!loop_var.empty() && ends_with(alloc.loop_var, loop_var)) || (loop_var.empty() && alloc.loop_var.empty())) { body = Allocate::make(alloc.name, alloc.type, alloc.memory_type, {alloc.size}, const_true(), body); } diff --git a/src/Generator.cpp b/src/Generator.cpp index 724e36dd6fd4..0950396807d1 100644 --- a/src/Generator.cpp +++ b/src/Generator.cpp @@ -7,36 +7,73 @@ #include #include -#include "BoundaryConditions.h" #include "CompilerLogger.h" -#include "Derivative.h" #include "Generator.h" #include "IRPrinter.h" #include "Module.h" #include "Simplify.h" +#ifdef HALIDE_ALLOW_GENERATOR_BUILD_METHOD +#pragma message "Support for Generator build() methods has been removed in Halide version 15." +#endif + namespace Halide { +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE GeneratorContext::GeneratorContext(const Target &target, +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API bool auto_schedule, const MachineParams &machine_params, - std::shared_ptr externs_map, - std::shared_ptr value_tracker) +#else + const AutoschedulerParams &autoscheduler_params, +#endif + std::shared_ptr externs_map) : target_(target), +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API auto_schedule_(auto_schedule), machine_params_(machine_params), - externs_map_(std::move(externs_map)), - value_tracker_(std::move(value_tracker)) { +#else + autoscheduler_params_(autoscheduler_params), +#endif + externs_map_(std::move(externs_map)) { } +#endif // HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API GeneratorContext::GeneratorContext(const Target &target, bool auto_schedule, const MachineParams &machine_params) - : GeneratorContext(target, - auto_schedule, - machine_params, - std::make_shared(), - std::make_shared()) { + : target_(target), + auto_schedule_(auto_schedule), + machine_params_(machine_params) { +} +#else +GeneratorContext::GeneratorContext(const Target &target) + : target_(target), + autoscheduler_params_() { +} + +GeneratorContext::GeneratorContext(const Target &target, + const AutoschedulerParams &autoscheduler_params) + : target_(target), + autoscheduler_params_(autoscheduler_params) { +} +#endif + +GeneratorContext GeneratorContext::with_target(const Target &t) const { +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE + return GeneratorContext(t, auto_schedule_, machine_params_, externs_map_); +#else + return GeneratorContext(t, auto_schedule_, machine_params_); +#endif +#else +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE + return GeneratorContext(t, autoscheduler_params_, externs_map_); +#else + return GeneratorContext(t, autoscheduler_params_); +#endif +#endif } namespace Internal { @@ -75,18 +112,14 @@ bool is_valid_name(const std::string &n) { return false; } } + // prohibit this specific string so that we can use it for + // passing GeneratorParams in Python. + if (n == "generator_params") { + return false; + } return true; } -std::string compute_base_path(const std::string &output_dir, - const std::string &function_name, - const std::string &file_base_name) { - std::vector namespaces; - std::string simple_name = extract_namespaces(function_name, namespaces); - std::string base_path = output_dir + "/" + (file_base_name.empty() ? simple_name : file_base_name); - return base_path; -} - std::map compute_output_files(const Target &target, const std::string &base_path, const std::set &outputs) { @@ -99,17 +132,9 @@ std::map compute_output_files(const Target &target, return output_files; } -Argument to_argument(const Internal::Parameter ¶m) { - return Argument(param.name(), - param.is_buffer() ? Argument::InputBuffer : Argument::InputScalar, - param.type(), - param.dimensions(), - param.get_argument_estimates()); -} - Func make_param_func(const Parameter &p, const std::string &name) { internal_assert(p.is_buffer()); - Func f(name + "_im"); + Func f(p.type(), p.dimensions(), name + "_im"); auto b = p.buffer(); if (b.defined()) { // If the Parameter has an explicit BufferPtr set, bind directly to it @@ -140,106 +165,6 @@ std::vector parse_halide_type_list(const std::string &types) { return result; } -/** - * ValueTracker is an internal utility class that attempts to track and flag certain - * obvious Stub-related errors at Halide compile time: it tracks the constraints set - * on any Parameter-based argument (i.e., Input and Output) to - * ensure that incompatible values aren't set. - * - * e.g.: if a Generator A requires stride[0] == 1, - * and Generator B uses Generator A via stub, but requires stride[0] == 4, - * we should be able to detect this at Halide compilation time, and fail immediately, - * rather than producing code that fails at runtime and/or runs slowly due to - * vectorization being unavailable. - * - * We do this by tracking the active values at entrance and exit to all user-provided - * Generator methods (generate()/schedule()); if we ever find more than two unique - * values active, we know we have a potential conflict. ("two" here because the first - * value is the default value for a given constraint.) - * - * Note that this won't catch all cases: - * -- JIT compilation has no way to check for conflicts at the top-level - * -- constraints that match the default value (e.g. if dim(0).set_stride(1) is the - * first value seen by the tracker) will be ignored, so an explicit requirement set - * this way can be missed - * - * Nevertheless, this is likely to be much better than nothing when composing multiple - * layers of Stubs in a single fused result. - */ -class ValueTracker { -private: - std::map>> values_history; - const size_t max_unique_values; - -public: - explicit ValueTracker(size_t max_unique_values = 2) - : max_unique_values(max_unique_values) { - } - void track_values(const std::string &name, const std::vector &values); -}; - -void ValueTracker::track_values(const std::string &name, const std::vector &values) { - std::vector> &history = values_history[name]; - if (history.empty()) { - for (const auto &value : values) { - history.push_back({value}); - } - return; - } - - internal_assert(history.size() == values.size()) - << "Expected values of size " << history.size() - << " but saw size " << values.size() - << " for name " << name << "\n"; - - // For each item, see if we have a new unique value - for (size_t i = 0; i < values.size(); ++i) { - Expr oldval = history[i].back(); - Expr newval = values[i]; - if (oldval.defined() && newval.defined()) { - if (can_prove(newval == oldval)) { - continue; - } - } else if (!oldval.defined() && !newval.defined()) { - // Expr::operator== doesn't work with undefined - // values, but they are equal for our purposes here. - continue; - } - history[i].push_back(newval); - // If we exceed max_unique_values, fail immediately. - // TODO: could be useful to log all the entries that - // overflow max_unique_values before failing. - // TODO: this could be more helpful about labeling the values - // that have multiple setttings. - if (history[i].size() > max_unique_values) { - std::ostringstream o; - o << "Saw too many unique values in ValueTracker[" + std::to_string(i) + "]; " - << "expected a maximum of " << max_unique_values << ":\n"; - for (const auto &e : history[i]) { - o << " " << e << "\n"; - } - user_error << o.str(); - } - } -} - -std::vector parameter_constraints(const Parameter &p) { - internal_assert(p.defined()); - std::vector values; - values.emplace_back(p.host_alignment()); - if (p.is_buffer()) { - for (int i = 0; i < p.dimensions(); ++i) { - values.push_back(p.min_constraint(i)); - values.push_back(p.extent_constraint(i)); - values.push_back(p.stride_constraint(i)); - } - } else { - values.push_back(p.min_value()); - values.push_back(p.max_value()); - } - return values; -} - class StubEmitter { public: StubEmitter(std::ostream &dest, @@ -282,11 +207,18 @@ class StubEmitter { std::vector out; for (auto *p : in) { // These are always propagated specially. +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API if (p->name() == "target" || p->name() == "auto_schedule" || p->name() == "machine_params") { continue; } +#else + if (p->name() == "target" || + p->name() == "autoscheduler") { + continue; + } +#endif if (p->is_synthetic_param()) { continue; } @@ -324,7 +256,11 @@ void StubEmitter::emit_generator_params_struct() { indent_level++; std::string comma = ""; for (auto *p : v) { - stream << get_indent() << comma << p->get_c_type() << " " << p->name() << "\n"; + std::string c_type = p->get_c_type(); + if (c_type == "AutoschedulerParams") { + c_type = "const AutoschedulerParams&"; + } + stream << get_indent() << comma << c_type << " " << p->name() << "\n"; comma = ", "; } indent_level--; @@ -341,25 +277,6 @@ void StubEmitter::emit_generator_params_struct() { stream << "\n"; } - stream << get_indent() << "inline HALIDE_NO_USER_CODE_INLINE Halide::Internal::GeneratorParamsMap to_generator_params_map() const {\n"; - indent_level++; - stream << get_indent() << "return {\n"; - indent_level++; - std::string comma = ""; - for (auto *p : v) { - stream << get_indent() << comma << "{\"" << p->name() << "\", "; - if (p->is_looplevel_param()) { - stream << p->name() << "}\n"; - } else { - stream << p->call_to_string(p->name()) << "}\n"; - } - comma = ", "; - } - indent_level--; - stream << get_indent() << "};\n"; - indent_level--; - stream << get_indent() << "}\n"; - indent_level--; stream << get_indent() << "};\n"; stream << "\n"; @@ -438,11 +355,17 @@ void StubEmitter::emit() { for (auto *output : outputs) { std::string c_type = output->get_c_type(); const bool is_func = (c_type == "Func"); - std::string getter = is_func ? "get_outputs" : "get_output_buffers<" + c_type + ">"; - std::string getter_suffix = output->is_array() ? "" : ".at(0)"; + std::string getter = "generator->output_func(\"" + output->name() + "\")"; + if (!is_func) { + getter = c_type + "::to_output_buffers(" + getter + ", generator)"; + } + if (!output->is_array()) { + getter = getter + ".at(0)"; + } + out_info.push_back({output->name(), output->is_array() ? "std::vector<" + c_type + ">" : c_type, - getter + "(\"" + output->name() + "\")" + getter_suffix}); + getter}); if (c_type != "Func") { all_outputs_are_func = false; } @@ -463,6 +386,7 @@ void StubEmitter::emit() { stream << "\n"; stream << get_indent() << "#include \n"; + stream << get_indent() << "#include \n"; stream << get_indent() << "#include \n"; stream << get_indent() << "#include \n"; stream << get_indent() << "#include \n"; @@ -474,7 +398,7 @@ void StubEmitter::emit() { stream << "namespace halide_register_generator {\n"; stream << "namespace " << generator_registered_name << "_ns {\n"; - stream << "extern std::unique_ptr factory(const Halide::GeneratorContext& context);\n"; + stream << "extern std::unique_ptr factory(const Halide::GeneratorContext& context);\n"; stream << "} // namespace halide_register_generator\n"; stream << "} // namespace " << generator_registered_name << "\n"; stream << "\n"; @@ -618,29 +542,48 @@ void StubEmitter::emit() { stream << get_indent() << ")\n"; stream << get_indent() << "{\n"; indent_level++; - stream << get_indent() << "using Stub = Halide::Internal::GeneratorStub;\n"; - stream << get_indent() << "Stub stub(\n"; - indent_level++; - stream << get_indent() << "context,\n"; - stream << get_indent() << "halide_register_generator::" << generator_registered_name << "_ns::factory,\n"; - stream << get_indent() << "generator_params.to_generator_params_map(),\n"; - stream << get_indent() << "{\n"; - indent_level++; - for (auto *input : inputs) { - stream << get_indent() << "Stub::to_stub_input_vector(inputs." << input->name() << ")"; - stream << ",\n"; + stream << get_indent() << "std::shared_ptr generator = halide_register_generator::" << generator_registered_name << "_ns::factory(context);\n"; + for (auto *p : generator_params) { + stream << get_indent(); + if (p->is_looplevel_param()) { + stream << "generator->set_generatorparam_value("; + } else { + stream << "generator->set_generatorparam_value("; + } + stream << "\"" << p->name() << "\", "; + if (p->is_looplevel_param()) { + stream << "generator_params." << p->name(); + } else { + stream << p->call_to_string("generator_params." + p->name()); + } + stream << ");\n"; } - indent_level--; - stream << get_indent() << "}\n"; - indent_level--; - stream << get_indent() << ");\n"; + for (auto *p : inputs) { + stream << get_indent() << "generator->bind_input(" + << "\"" << p->name() << "\", "; + if (p->kind() == ArgInfoKind::Buffer) { + stream << "Halide::Internal::StubInputBuffer<>::to_parameter_vector(inputs." << p->name() << ")"; + } else { + // Func or Expr + if (!p->is_array()) { + stream << "{"; + } + stream << "inputs." << p->name(); + if (!p->is_array()) { + stream << "}"; + } + } + stream << ");\n"; + } + + stream << get_indent() << "generator->build_pipeline();\n"; stream << get_indent() << "return {\n"; indent_level++; for (const auto &out : out_info) { - stream << get_indent() << "stub." << out.getter << ",\n"; + stream << get_indent() << out.getter << ",\n"; } - stream << get_indent() << "stub.generator->context().get_target()\n"; + stream << get_indent() << "generator->context().target()\n"; indent_level--; stream << get_indent() << "};\n"; indent_level--; @@ -691,68 +634,6 @@ void StubEmitter::emit() { stream << get_indent() << "#endif // " << guard.str() << "\n"; } -GeneratorStub::GeneratorStub(const GeneratorContext &context, - const GeneratorFactory &generator_factory) - : generator(generator_factory(context)) { -} - -GeneratorStub::GeneratorStub(const GeneratorContext &context, - const GeneratorFactory &generator_factory, - const GeneratorParamsMap &generator_params, - const std::vector> &inputs) - : GeneratorStub(context, generator_factory) { - generate(generator_params, inputs); -} - -// Return a vector of all Outputs of this Generator; non-array outputs are returned -// as a vector-of-size-1. This method is primarily useful for code that needs -// to iterate through the outputs of unknown, arbitrary Generators (e.g., -// the Python bindings). -std::vector> GeneratorStub::generate(const GeneratorParamsMap &generator_params, - const std::vector> &inputs) { - generator->set_generator_param_values(generator_params); - generator->ensure_configure_has_been_called(); - generator->set_inputs_vector(inputs); - Pipeline p = generator->build_pipeline(); - - std::vector> v; - GeneratorParamInfo &pi = generator->param_info(); -#ifdef HALIDE_ALLOW_GENERATOR_BUILD_METHOD - if (!pi.outputs().empty()) { - for (auto *output : pi.outputs()) { - v.push_back(get_outputs(output->name())); - } - } else { - // Generators with build() method can't have Output<>, hence can't have array outputs - for (const auto &output : p.outputs()) { - v.push_back(std::vector{output}); - } - } -#else - internal_assert(!pi.outputs().empty()); - for (auto *output : pi.outputs()) { - v.push_back(get_outputs(output->name())); - } -#endif - return v; -} - -GeneratorStub::Names GeneratorStub::get_names() const { - generator->ensure_configure_has_been_called(); - auto &pi = generator->param_info(); - Names names; - for (auto *o : pi.generator_params()) { - names.generator_params.push_back(o->name()); - } - for (auto *o : pi.inputs()) { - names.inputs.push_back(o->name()); - } - for (auto *o : pi.outputs()) { - names.outputs.push_back(o->name()); - } - return names; -} - const std::map &get_halide_type_enum_map() { static const std::map halide_type_enum_map{ {"bool", Bool()}, @@ -803,7 +684,9 @@ std::string halide_type_to_c_type(const Type &t) { namespace { -int generate_filter_main_inner(int argc, char **argv, std::ostream &error_output) { +int generate_filter_main_inner(int argc, + char **argv, + const GeneratorFactoryProvider &generator_factory_provider) { static const char kUsage[] = R"INLINE_CODE( gengen [-g GENERATOR_NAME] [-f FUNCTION_NAME] [-o OUTPUT_DIR] [-r RUNTIME_NAME] @@ -836,8 +719,6 @@ gengen find one. Flags across all of the targets that do not affect runtime code generation, such as `no_asserts` and `no_runtime`, are ignored. - -s The name of an autoscheduler to set as the default. - -t Timeout for the Generator to run, in seconds; mainly useful to ensure that bugs and/or degenerate cases don't stall build systems. Defaults to 900 (=15 minutes). Specify 0 to allow ~infinite time. @@ -853,310 +734,467 @@ gengen {"-o", ""}, {"-p", ""}, {"-r", ""}, +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API {"-s", ""}, +#endif {"-t", "900"}, // 15 minutes }; - GeneratorParamsMap generator_args; + + ExecuteGeneratorArgs args; for (int i = 1; i < argc; ++i) { if (argv[i][0] != '-') { std::vector v = split_string(argv[i], "="); - if (v.size() != 2 || v[0].empty() || v[1].empty()) { - error_output << kUsage; - return 1; - } - generator_args[v[0]] = v[1]; - continue; - } - auto it = flags_info.find(argv[i]); - if (it != flags_info.end()) { - if (i + 1 >= argc) { - error_output << kUsage; - return 1; - } + user_assert(v.size() == 2 && !v[0].empty() && !v[1].empty()) << kUsage; + args.generator_params[v[0]] = v[1]; + } else if (auto it = flags_info.find(argv[i]); it != flags_info.end()) { + user_assert(i + 1 < argc) << kUsage; it->second = argv[i + 1]; ++i; continue; + } else { + if (!strcmp(argv[i], "-s")) { +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + user_warning << "HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API is deprecated in Halide 15 " + "(and will be removed in Halide 16).\n"; +#else + user_error << "-s is no longer supported for setting autoscheduler; specify autoschduler.name=NAME instead.\n" + << kUsage; +#endif + } + user_error << "Unknown flag: " << argv[i] << "\n" + << kUsage; } - error_output << "Unknown flag: " << argv[i] << "\n"; - error_output << kUsage; - return 1; } // It's possible that in the future loaded plugins might change // how arguments are parsed, so we handle those first. - for (const auto &lib : split_string(flags_info["-p"], ",")) { - if (!lib.empty()) { - load_plugin(lib); + for (const auto &lib_path : split_string(flags_info["-p"], ",")) { + if (!lib_path.empty()) { + load_plugin(lib_path); } } - if (flags_info["-d"] != "1" && flags_info["-d"] != "0") { - error_output << "-d must be 0 or 1\n"; - error_output << kUsage; - return 1; - } - const int build_gradient_module = flags_info["-d"] == "1"; - - std::string autoscheduler_name = flags_info["-s"]; +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + const auto autoscheduler_name = flags_info["-s"]; if (!autoscheduler_name.empty()) { Pipeline::set_default_autoscheduler_name(autoscheduler_name); } +#else + if (args.generator_params.count("auto_schedule")) { + user_error << "auto_schedule=true is no longer supported for enabling autoscheduling; specify autoscheduler=NAME instead.\n" + << kUsage; + } + if (args.generator_params.count("machine_params")) { + user_error << "machine_params is no longer supported as a GeneratorParam; specify autoscheduler.FIELD=VALUE instead.\n" + << kUsage; + } +#endif + + const auto &d_val = flags_info["-d"]; + user_assert(d_val == "1" || d_val == "0") << "-d must be 0 or 1\n" + << kUsage; - std::string runtime_name = flags_info["-r"]; + const std::vector generator_names = generator_factory_provider.enumerate(); + + const auto create_generator = [&](const std::string &generator_name, const Halide::GeneratorContext &context) -> AbstractGeneratorPtr { + internal_assert(generator_name == args.generator_name); + auto g = generator_factory_provider.create(generator_name, context); + if (!g) { + std::ostringstream o; + o << "Generator not found: " << generator_name << "\n"; + o << "Did you mean:\n"; + for (const auto &n : generator_names) { + o << " " << n << "\n"; + } + user_error << o.str(); + } + return g; + }; - std::vector generator_names = GeneratorRegistry::enumerate(); - if (generator_names.empty() && runtime_name.empty()) { - error_output << "No generators have been registered and not compiling a standalone runtime\n"; - error_output << kUsage; - return 1; + const auto build_target_strings = [](GeneratorParamsMap *gp) { + std::vector target_strings; + if (gp->find("target") != gp->end()) { + target_strings = split_string((*gp)["target"], ","); + gp->erase("target"); + } + return target_strings; + }; + + const auto build_targets = [](const std::vector &target_strings) { + std::vector targets; + for (const auto &s : target_strings) { + targets.emplace_back(s); + } + return targets; + }; + + const auto build_output_types = [&]() { + std::set output_types; + + std::string emit_flags_string = flags_info["-e"]; + // If HL_EXTRA_OUTPUTS is defined, assume it's extra outputs we want to generate + // (usually for temporary debugging purposes) and just tack it on to the -e contents. + std::string extra_outputs = get_env_variable("HL_EXTRA_OUTPUTS"); + if (!extra_outputs.empty()) { + if (!emit_flags_string.empty()) { + emit_flags_string += ","; + } + emit_flags_string += extra_outputs; + } + + const std::vector emit_flags = split_string(emit_flags_string, ","); + + if (emit_flags.empty() || (emit_flags.size() == 1 && emit_flags[0].empty())) { + // If omitted or empty, assume .a and .h and registration.cpp + output_types.insert(OutputFileType::c_header); + output_types.insert(OutputFileType::registration); + output_types.insert(OutputFileType::static_library); + } else { + // Build a reverse lookup table. Allow some legacy aliases on the command line, + // to allow legacy build systems to work more easily. + std::map output_name_to_enum = { + {"cpp", OutputFileType::c_source}, + {"h", OutputFileType::c_header}, + {"html", OutputFileType::stmt_html}, + {"o", OutputFileType::object}, + {"py.c", OutputFileType::python_extension}, + }; + // extensions won't vary across multitarget output + const Target t = args.targets.empty() ? Target() : args.targets[0]; + const std::map output_info = get_output_info(t); + for (const auto &it : output_info) { + output_name_to_enum[it.second.name] = it.first; + } + + for (const std::string &opt : emit_flags) { + auto it = output_name_to_enum.find(opt); + if (it == output_name_to_enum.end()) { + std::ostringstream o; + o << "Unrecognized emit option: " << opt << " is not one of ["; + auto end = output_info.cend(); + auto last = std::prev(end); + for (auto iter = output_info.cbegin(); iter != end; ++iter) { + o << iter->second.name; + if (iter != last) { + o << " "; + } + } + o << "], ignoring.\n"; + o << kUsage; + user_error << o.str(); + } + output_types.insert(it->second); + } + } + return output_types; + }; + + // Always specify target_strings for suffixes: if we omit this, we'll use *canonical* target strings + // for suffixes, but our caller might have passed non-canonical-but-still-legal target strings, + // and if we don't use those, the output filenames might not match what the caller expects. + args.suffixes = build_target_strings(&args.generator_params); + args.targets = build_targets(args.suffixes); + args.output_dir = flags_info["-o"]; + args.output_types = build_output_types(); + args.generator_name = flags_info["-g"]; + args.function_name = flags_info["-f"]; + args.file_base_name = flags_info["-n"]; + args.runtime_name = flags_info["-r"]; + args.build_mode = (d_val == "1") ? ExecuteGeneratorArgs::Gradient : ExecuteGeneratorArgs::Default; + args.create_generator = create_generator; + // args.generator_params is already set + + // Allow quick-n-dirty use of compiler logging via HL_DEBUG_COMPILER_LOGGER env var + const bool do_compiler_logging = args.output_types.count(OutputFileType::compiler_log) || + (get_env_variable("HL_DEBUG_COMPILER_LOGGER") == "1"); + if (do_compiler_logging) { + const bool obfuscate_compiler_logging = get_env_variable("HL_OBFUSCATE_COMPILER_LOGGER") == "1"; + args.compiler_logger_factory = + [obfuscate_compiler_logging, &args](const std::string &function_name, const Target &target) -> std::unique_ptr { + // rebuild generator_args from the map so that they are always canonical + std::string generator_args_string, autoscheduler_name; + std::string sep; + for (const auto &it : args.generator_params) { + std::string quote = it.second.find(' ') != std::string::npos ? "\\\"" : ""; + generator_args_string += sep + it.first + "=" + quote + it.second + quote; + sep = " "; + if (it.first == "autoscheduler") { + autoscheduler_name = it.second; + } + } + std::unique_ptr t(new JSONCompilerLogger( + obfuscate_compiler_logging ? "" : args.generator_name, + obfuscate_compiler_logging ? "" : args.function_name, + obfuscate_compiler_logging ? "" : autoscheduler_name, + obfuscate_compiler_logging ? Target() : target, + obfuscate_compiler_logging ? "" : generator_args_string, + obfuscate_compiler_logging)); + return t; + }; } - std::string generator_name = flags_info["-g"]; - if (generator_name.empty() && runtime_name.empty()) { - // Require either -g or -r to be specified: - // no longer infer the name when only one Generator is registered - error_output << "Either -g or -r must be specified; available Generators are:\n"; + // Do some preflighting here to emit errors that are likely from the command line + // but not necessarily from the API call. + user_assert(!(generator_names.empty() && args.runtime_name.empty())) + << "No generators have been registered and not compiling a standalone runtime\n" + << kUsage; + + if (args.generator_name.empty() && args.runtime_name.empty()) { + // Require at least one of -g or -r to be specified. + std::ostringstream o; + o << "Either -g or -r must be specified; available Generators are:\n"; if (!generator_names.empty()) { for (const auto &name : generator_names) { - error_output << " " << name << "\n"; + o << " " << name << "\n"; } } else { - error_output << " \n"; + o << " \n"; } - return 1; + user_error << o.str(); } - std::string function_name = flags_info["-f"]; - if (function_name.empty()) { - // If -f isn't specified, assume function name = generator name. - function_name = generator_name; - } - std::string output_dir = flags_info["-o"]; - if (output_dir.empty()) { - error_output << "-o must always be specified.\n"; - error_output << kUsage; - return 1; + { + // TODO: should we move the TimeoutMonitor stuff to execute_generator? + // It seems more likely to be useful here. + + struct TimeoutMonitor { + std::atomic generator_finished = false; + std::thread thread; + std::condition_variable cond_var; + std::mutex mutex; + + // Kill the timeout monitor as a destructor to ensure the thread + // gets joined in the event of an exception + ~TimeoutMonitor() { + generator_finished = true; + cond_var.notify_all(); + thread.join(); + } + } monitor; + + const int timeout_in_seconds = std::stoi(flags_info["-t"]); + const auto timeout_time = std::chrono::steady_clock::now() + std::chrono::seconds(timeout_in_seconds); + monitor.thread = std::thread([timeout_time, timeout_in_seconds, &monitor]() { + std::unique_lock lock(monitor.mutex); + + if (timeout_in_seconds <= 0) { + // No watchdog timer, just let it run as long as it likes. + return; + } + while (!monitor.generator_finished) { + auto now = std::chrono::steady_clock::now(); + if (now > timeout_time) { + fprintf(stderr, "Timed out waiting for Generator to complete (%d seconds)!\n", timeout_in_seconds); + fflush(stdout); + fflush(stderr); + exit(1); + } else { + monitor.cond_var.wait_for(lock, timeout_time - now); + } + } + }); + + execute_generator(args); } + return 0; +} - std::string emit_flags_string = flags_info["-e"]; +class GeneratorsFromRegistry : public GeneratorFactoryProvider { +public: + GeneratorsFromRegistry() = default; + ~GeneratorsFromRegistry() override = default; - // If HL_EXTRA_OUTPUTS is defined, assume it's extra outputs we want to generate - // (usually for temporary debugging purposes) and just tack it on to the -e contents. - std::string extra_outputs = get_env_variable("HL_EXTRA_OUTPUTS"); - if (!extra_outputs.empty()) { - if (!emit_flags_string.empty()) { - emit_flags_string += ","; - } - emit_flags_string += extra_outputs; + std::vector enumerate() const override { + return GeneratorRegistry::enumerate(); } - // It's ok to omit "target=" if we are generating *only* a cpp_stub - const std::vector emit_flags = split_string(emit_flags_string, ","); - const bool stub_only = (emit_flags.size() == 1 && emit_flags[0] == "cpp_stub"); - if (!stub_only) { - if (generator_args.find("target") == generator_args.end()) { - error_output << "Target missing\n"; - error_output << kUsage; - return 1; - } + AbstractGeneratorPtr create(const std::string &name, + const Halide::GeneratorContext &context) const override { + return GeneratorRegistry::create(name, context); } +}; + +} // namespace - // it's OK for file_base_name to be empty: filename will be based on function name - std::string file_base_name = flags_info["-n"]; +const GeneratorFactoryProvider &get_registered_generators() { + static GeneratorsFromRegistry g; + return g; +} - auto target_strings = split_string(generator_args["target"].string_value, ","); - std::vector targets; - for (const auto &s : target_strings) { - targets.emplace_back(s); +} // namespace Internal + +Callable create_callable_from_generator(const GeneratorContext &context, + const std::string &name, + const GeneratorParamsMap &generator_params) { + auto g = Internal::get_registered_generators().create(name, context); + user_assert(g != nullptr) << "There is no Generator with the name '" << name << "' currently available."; + g->set_generatorparam_values(generator_params); + return g->compile_to_callable(); +} + +Callable create_callable_from_generator(const Target &target, + const std::string &name, + const GeneratorParamsMap &generator_params) { + return create_callable_from_generator(GeneratorContext(target), name, generator_params); +} + +namespace Internal { + +#ifdef HALIDE_WITH_EXCEPTIONS +int generate_filter_main(int argc, char **argv, const GeneratorFactoryProvider &generator_factory_provider) { + try { + return generate_filter_main_inner(argc, argv, generator_factory_provider); + } catch (::Halide::Error &err) { + user_error << "Unhandled exception: " << err.what() << "\n"; + return -1; + } catch (std::exception &err) { + user_error << "Unhandled exception: " << err.what() << "\n"; + return -1; + } catch (...) { + user_error << "Unhandled exception: (unknown)\n"; + return -1; } +} +#else +int generate_filter_main(int argc, char **argv, const GeneratorFactoryProvider &generator_factory_provider) { + return generate_filter_main_inner(argc, argv, generator_factory_provider); +} +#endif - // extensions won't vary across multitarget output - std::map output_info = get_output_info(targets[0]); +int generate_filter_main(int argc, char **argv) { + return generate_filter_main(argc, argv, GeneratorsFromRegistry()); +} - std::set outputs; - if (emit_flags.empty() || (emit_flags.size() == 1 && emit_flags[0].empty())) { - // If omitted or empty, assume .a and .h and registration.cpp - outputs.insert(OutputFileType::c_header); - outputs.insert(OutputFileType::registration); - outputs.insert(OutputFileType::static_library); - } else { - // Build a reverse lookup table. Allow some legacy aliases on the command line, - // to allow legacy build systems to work more easily. - std::map output_name_to_enum = { - {"cpp", OutputFileType::c_source}, - {"h", OutputFileType::c_header}, - {"html", OutputFileType::stmt_html}, - {"o", OutputFileType::object}, - {"py.c", OutputFileType::python_extension}, - }; - for (const auto &it : output_info) { - output_name_to_enum[it.second.name] = it.first; +void execute_generator(const ExecuteGeneratorArgs &args_in) { + const auto fix_defaults = [](const ExecuteGeneratorArgs &args_in) -> ExecuteGeneratorArgs { + ExecuteGeneratorArgs args = args_in; + if (!args.create_generator) { + args.create_generator = [](const std::string &generator_name, const GeneratorContext &context) -> AbstractGeneratorPtr { + return GeneratorRegistry::create(generator_name, context); + }; } - - for (const std::string &opt : emit_flags) { - auto it = output_name_to_enum.find(opt); - if (it == output_name_to_enum.end()) { - error_output << "Unrecognized emit option: " << opt << " is not one of ["; - auto end = output_info.cend(); - auto last = std::prev(end); - for (auto iter = output_info.cbegin(); iter != end; ++iter) { - error_output << iter->second.name; - if (iter != last) { - error_output << " "; - } - } - error_output << "], ignoring.\n"; - error_output << kUsage; - return 1; - } - outputs.insert(it->second); + if (!args.compiler_logger_factory) { + args.compiler_logger_factory = [](const std::string &, const Target &) -> std::unique_ptr { + return nullptr; + }; } - } + if (args.function_name.empty()) { + args.function_name = args.generator_name; + } + if (args.file_base_name.empty()) { + args.file_base_name = strip_namespaces(args.function_name); + } + return args; + }; - // Allow quick-n-dirty use of compiler logging via HL_DEBUG_COMPILER_LOGGER env var - const bool do_compiler_logging = outputs.count(OutputFileType::compiler_log) || - (get_env_variable("HL_DEBUG_COMPILER_LOGGER") == "1"); + const ExecuteGeneratorArgs args = fix_defaults(args_in); - const bool obfuscate_compiler_logging = get_env_variable("HL_OBFUSCATE_COMPILER_LOGGER") == "1"; + // -------------- Do some sanity checking. + internal_assert(!args.output_dir.empty()); - const CompilerLoggerFactory no_compiler_logger_factory = - [](const std::string &, const Target &) -> std::unique_ptr { - return nullptr; - }; + const bool cpp_stub_only = args.output_types.size() == 1 && + args.output_types.count(OutputFileType::cpp_stub) == 1; + if (!cpp_stub_only) { + // It's ok to leave targets unspecified if we are generating *only* a cpp_stub + internal_assert(!args.targets.empty()); + } - const CompilerLoggerFactory json_compiler_logger_factory = - [&](const std::string &function_name, const Target &target) -> std::unique_ptr { - // rebuild generator_args from the map so that they are always canonical - std::string generator_args_string; - std::string sep; - for (const auto &it : generator_args) { - if (it.first == "target") { - continue; - } - std::string quote = it.second.string_value.find(' ') != std::string::npos ? "\\\"" : ""; - generator_args_string += sep + it.first + "=" + quote + it.second.string_value + quote; - sep = " "; + const auto ensure_valid_name = [](const std::string &s) { + internal_assert(s.empty() || is_valid_name(s)) << "string '" << s << "' is not a valid Generator name."; + }; + const auto ensure_not_pathname = [](const std::string &s) { + for (char c : "/\\") { + internal_assert(s.find(c) == std::string::npos) << "string '" << s << "' must not contain '" << c << "', but saw '" << s << "'"; } - std::unique_ptr t(new JSONCompilerLogger( - obfuscate_compiler_logging ? "" : generator_name, - obfuscate_compiler_logging ? "" : function_name, - obfuscate_compiler_logging ? "" : autoscheduler_name, - obfuscate_compiler_logging ? Target() : target, - obfuscate_compiler_logging ? "" : generator_args_string, - obfuscate_compiler_logging)); - return t; }; - const CompilerLoggerFactory compiler_logger_factory = do_compiler_logging ? - json_compiler_logger_factory : - no_compiler_logger_factory; - - struct TimeoutMonitor { - std::atomic generator_finished = false; - std::thread thread; - std::condition_variable cond_var; - std::mutex mutex; - - // Kill the timeout monitor as a destructor to ensure the thread - // gets joined in the event of an exception - ~TimeoutMonitor() { - generator_finished = true; - cond_var.notify_all(); - thread.join(); - } - } monitor; + // These should be valid Generator names by the rules of is_valid_name() + ensure_valid_name(args.generator_name); - const int timeout_in_seconds = std::stoi(flags_info["-t"]); - const auto timeout_time = std::chrono::steady_clock::now() + std::chrono::seconds(timeout_in_seconds); - monitor.thread = std::thread([timeout_time, timeout_in_seconds, &monitor]() { - std::unique_lock lock(monitor.mutex); + // These should be valid "leaf" filenames, but not full or partial pathnames + ensure_not_pathname(args.runtime_name); + ensure_not_pathname(args.function_name); + ensure_not_pathname(args.file_base_name); + for (const auto &s : args.suffixes) { + ensure_not_pathname(s); + } - if (timeout_in_seconds <= 0) { - // No watchdog timer, just let it run as long as it likes. - return; - } - while (!monitor.generator_finished) { - auto now = std::chrono::steady_clock::now(); - if (now > timeout_time) { - fprintf(stderr, "Timed out waiting for Generator to complete (%d seconds)!\n", timeout_in_seconds); - fflush(stdout); - fflush(stderr); - exit(1); - } else { - monitor.cond_var.wait_for(lock, timeout_time - now); - } - } - }); - - if (!runtime_name.empty()) { - std::string base_path = compute_base_path(output_dir, runtime_name, ""); - - Target gcd_target = targets[0]; - for (size_t i = 1; i < targets.size(); i++) { - if (!gcd_target.get_runtime_compatible_target(targets[i], gcd_target)) { - error_output << "Failed to find compatible runtime target for " - << gcd_target.to_string() - << " and " - << targets[i].to_string() << "\n"; - return -1; - } + // -------------- Process the arguments. + + if (!args.runtime_name.empty()) { + // Runtime always ignores file_base_name + const std::string base_path = args.output_dir + "/" + args.runtime_name; + + Target gcd_target = args.targets[0]; + for (size_t i = 1; i < args.targets.size(); i++) { + internal_assert(gcd_target.get_runtime_compatible_target(args.targets[i], gcd_target)) + << "Failed to find compatible runtime target for " << gcd_target << " and " << args.targets[i]; } - if (targets.size() > 1) { - debug(1) << "Building runtime for computed target: " << gcd_target.to_string() << "\n"; + if (args.targets.size() > 1) { + debug(1) << "Building runtime for computed target: " << gcd_target << "\n"; } - auto output_files = compute_output_files(gcd_target, base_path, outputs); + auto output_files = compute_output_files(gcd_target, base_path, args.output_types); // Runtime doesn't get to participate in the CompilerLogger party compile_standalone_runtime(output_files, gcd_target); } - if (!generator_name.empty()) { - std::string base_path = compute_base_path(output_dir, function_name, file_base_name); - debug(1) << "Generator " << generator_name << " has base_path " << base_path << "\n"; - if (outputs.count(OutputFileType::cpp_stub)) { + if (!args.generator_name.empty()) { + const std::string base_path = args.output_dir + "/" + args.file_base_name; + debug(1) << "Generator " << args.generator_name << " has base_path " << base_path << "\n"; + if (args.output_types.count(OutputFileType::cpp_stub)) { // When generating cpp_stub, we ignore all generator args passed in, and supply a fake Target. // (CompilerLogger is never enabled for cpp_stub, for now anyway.) - auto gen = GeneratorRegistry::create(generator_name, GeneratorContext(Target())); - auto stub_file_path = base_path + output_info[OutputFileType::cpp_stub].extension; - gen->emit_cpp_stub(stub_file_path); + const Target fake_target = Target(); + auto gen = args.create_generator(args.generator_name, GeneratorContext(fake_target)); + auto output_files = compute_output_files(fake_target, base_path, args.output_types); + gen->emit_cpp_stub(output_files[OutputFileType::cpp_stub]); } // Don't bother with this if we're just emitting a cpp_stub. - if (!stub_only) { - auto output_files = compute_output_files(targets[0], base_path, outputs); - auto module_factory = [&generator_name, &generator_args, build_gradient_module](const std::string &name, const Target &target) -> Module { - auto sub_generator_args = generator_args; - sub_generator_args.erase("target"); - // Must re-create each time since each instance will have a different Target. - auto gen = GeneratorRegistry::create(generator_name, GeneratorContext(target)); - gen->set_generator_param_values(sub_generator_args); - return build_gradient_module ? gen->build_gradient_module(name) : gen->build_module(name); + if (!cpp_stub_only) { + auto output_files = compute_output_files(args.targets[0], base_path, args.output_types); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + const auto get_gp = [&](const std::string &key) { + auto it = args.generator_params.find(key); + return it != args.generator_params.end() ? it->second : ""; }; - compile_multitarget(function_name, output_files, targets, target_strings, module_factory, compiler_logger_factory); - } - } - - return 0; -} - -} // namespace - -#ifdef HALIDE_WITH_EXCEPTIONS -int generate_filter_main(int argc, char **argv, std::ostream &error_output) { - try { - return generate_filter_main_inner(argc, argv, error_output); - } catch (std::runtime_error &err) { - error_output << "Unhandled exception: " << err.what() << "\n"; - return -1; - } -} + const auto auto_schedule_string = get_gp("auto_schedule"); + const auto machine_params_string = get_gp("machine_params"); + const bool auto_schedule = auto_schedule_string == "true" || auto_schedule_string == "True"; + const MachineParams machine_params = !machine_params_string.empty() ? MachineParams(machine_params_string) : MachineParams::generic(); +#endif + auto module_factory = [&](const std::string &function_name, const Target &target) -> Module { + // Must re-create each time since each instance will have a different Target. +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + auto gen = args.create_generator(args.generator_name, GeneratorContext(target, auto_schedule, machine_params)); + for (const auto &kv : args.generator_params) { + if (kv.first == "target" || + kv.first == "auto_schedule" || + kv.first == "machine_params") { + continue; + } + gen->set_generatorparam_value(kv.first, kv.second); + } #else -int generate_filter_main(int argc, char **argv, std::ostream &error_output) { - return generate_filter_main_inner(argc, argv, error_output); -} + auto gen = args.create_generator(args.generator_name, GeneratorContext(target)); + for (const auto &kv : args.generator_params) { + if (kv.first == "target") { + continue; + } + gen->set_generatorparam_value(kv.first, kv.second); + } #endif + return args.build_mode == ExecuteGeneratorArgs::Gradient ? + gen->build_gradient_module(function_name) : + gen->build_module(function_name); + }; + compile_multitarget(args.function_name, output_files, args.targets, args.suffixes, module_factory, args.compiler_logger_factory); + } + } +} GeneratorParamBase::GeneratorParamBase(const std::string &name) : name_(name) { @@ -1170,18 +1208,20 @@ GeneratorParamBase::~GeneratorParamBase() { void GeneratorParamBase::check_value_readable() const { // These are always readable. +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API if (name() == "target" || name() == "auto_schedule" || name() == "machine_params") { return; } -#ifdef HALIDE_ALLOW_GENERATOR_BUILD_METHOD - user_assert(generator && generator->phase >= GeneratorBase::ConfigureCalled) - << "The GeneratorParam \"" << name() << "\" cannot be read before build() or configure()/generate() is called.\n"; #else + if (name() == "target" || + name() == "autoscheduler") { + return; + } +#endif user_assert(generator && generator->phase >= GeneratorBase::ConfigureCalled) << "The GeneratorParam \"" << name() << "\" cannot be read before configure()/generate() is called.\n"; -#endif } void GeneratorParamBase::check_value_writable() const { @@ -1189,19 +1229,58 @@ void GeneratorParamBase::check_value_writable() const { if (!generator) { return; } -#ifdef HALIDE_ALLOW_GENERATOR_BUILD_METHOD - user_assert(generator->phase < GeneratorBase::GenerateCalled) - << "The GeneratorParam \"" << name() << "\" cannot be written after build() or generate() is called.\n"; -#else user_assert(generator->phase < GeneratorBase::GenerateCalled) << "The GeneratorParam \"" << name() << "\" cannot be written after generate() is called.\n"; -#endif } void GeneratorParamBase::fail_wrong_type(const char *type) { user_error << "The GeneratorParam \"" << name() << "\" cannot be set with a value of type " << type << ".\n"; } +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API +// nothing +#else +GeneratorParam_AutoSchedulerParams::GeneratorParam_AutoSchedulerParams() + : GeneratorParamImpl("autoscheduler", {}) { +} + +void GeneratorParam_AutoSchedulerParams::set_from_string(const std::string &new_value_string) { + internal_error << "This method should never be called."; +} + +std::string GeneratorParam_AutoSchedulerParams::get_default_value() const { + internal_error << "This method should never be called."; + return ""; +} + +std::string GeneratorParam_AutoSchedulerParams::call_to_string(const std::string &v) const { + internal_error << "This method should never be called."; + return ""; +} + +std::string GeneratorParam_AutoSchedulerParams::get_c_type() const { + internal_error << "This method should never be called."; + return ""; +} + +bool GeneratorParam_AutoSchedulerParams::try_set(const std::string &key, const std::string &value) { + const auto &n = this->name(); + if (key == n) { + user_assert(this->value_.name.empty()) << "The GeneratorParam " << key << " cannot be set more than once.\n"; + this->value_.name = value; + return true; + } else if (starts_with(key, n + ".")) { + const auto sub_key = key.substr(n.size() + 1); + user_assert(this->value_.extra.count(sub_key) == 0) << "The GeneratorParam " << key << " cannot be set more than once.\n"; + this->value_.extra[sub_key] = value; + return true; + } else { + return false; + } +} + +#endif + /* static */ GeneratorRegistry &GeneratorRegistry::get_registry() { static GeneratorRegistry *registry = new GeneratorRegistry; @@ -1229,22 +1308,18 @@ void GeneratorRegistry::unregister_factory(const std::string &name) { } /* static */ -std::unique_ptr GeneratorRegistry::create(const std::string &name, - const GeneratorContext &context) { +AbstractGeneratorPtr GeneratorRegistry::create(const std::string &name, + const GeneratorContext &context) { GeneratorRegistry ®istry = get_registry(); std::lock_guard lock(registry.mutex); auto it = registry.factories.find(name); if (it == registry.factories.end()) { - std::ostringstream o; - o << "Generator not found: " << name << "\n"; - o << "Did you mean:\n"; - for (const auto &n : registry.factories) { - o << " " << n.first << "\n"; - } - user_error << o.str(); + return nullptr; } - std::unique_ptr g = it->second(context); - internal_assert(g != nullptr); + GeneratorFactory f = it->second; + AbstractGeneratorPtr g = f(context); + // Do not assert! Just return nullptr. + // internal_assert(g != nullptr); return g; } @@ -1278,10 +1353,10 @@ GeneratorParamInfo::GeneratorParamInfo(GeneratorBase *generator, const size_t si const std::string &n = gio->name(); const std::string &gn = generator->generator_registered_name; - owned_synthetic_params.push_back(GeneratorParam_Synthetic::make(generator, gn, n + ".type", *gio, SyntheticParamType::Type, gio->types_defined())); + owned_synthetic_params.push_back(GeneratorParam_Synthetic::make(generator, gn, n + ".type", *gio, SyntheticParamType::Type, gio->gio_types_defined())); filter_generator_params.push_back(owned_synthetic_params.back().get()); - if (gio->kind() != IOKind::Scalar) { + if (gio->kind() != ArgInfoKind::Scalar) { owned_synthetic_params.push_back(GeneratorParam_Synthetic::make(generator, gn, n + ".dim", *gio, SyntheticParamType::Dim, gio->dims_defined())); filter_generator_params.push_back(owned_synthetic_params.back().get()); } @@ -1342,65 +1417,46 @@ GeneratorParamInfo &GeneratorBase::param_info() { return *param_info_ptr; } -std::vector GeneratorBase::get_outputs(const std::string &n) { - check_min_phase(GenerateCalled); - auto *output = find_output_by_name(n); - // Call for the side-effect of asserting if the value isn't defined. - (void)output->array_size(); - for (const auto &f : output->funcs()) { - user_assert(f.defined()) << "Output " << n << " was not fully defined.\n"; - } - return output->funcs(); +GeneratorInputBase *GeneratorBase::find_input_by_name(const std::string &name) { + auto *t = GeneratorBase::find_by_name(name, param_info().inputs()); + internal_assert(t != nullptr) << "Input " << name << " not found."; + return t; } -// Find output by name. If not found, assert-fail. Never returns null. GeneratorOutputBase *GeneratorBase::find_output_by_name(const std::string &name) { - // There usually are very few outputs, so a linear search is fine - GeneratorParamInfo &pi = param_info(); - for (GeneratorOutputBase *output : pi.outputs()) { - if (output->name() == name) { - return output; - } - } - internal_error << "Output " << name << " not found."; - return nullptr; // not reached -} - -void GeneratorBase::set_generator_param_values(const GeneratorParamsMap ¶ms) { - GeneratorParamInfo &pi = param_info(); - - std::unordered_map generator_params_by_name; - for (auto *g : pi.generator_params()) { - generator_params_by_name[g->name()] = g; - } - - for (const auto &key_value : params) { - auto gp = generator_params_by_name.find(key_value.first); - user_assert(gp != generator_params_by_name.end()) - << "Generator " << generator_registered_name << " has no GeneratorParam named: " << key_value.first << "\n"; - if (gp->second->is_looplevel_param()) { - if (!key_value.second.string_value.empty()) { - gp->second->set_from_string(key_value.second.string_value); - } else { - gp->second->set(key_value.second.loop_level); - } - } else { - gp->second->set_from_string(key_value.second.string_value); - } - } + auto *t = GeneratorBase::find_by_name(name, param_info().outputs()); + internal_assert(t != nullptr) << "Output " << name << " not found."; + return t; } GeneratorContext GeneratorBase::context() const { - return GeneratorContext(target, auto_schedule, machine_params, externs_map, value_tracker); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE + return GeneratorContext(target, auto_schedule, machine_params, externs_map); +#else + return GeneratorContext(target, auto_schedule, machine_params); +#endif +#else +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE + return GeneratorContext(target, autoscheduler_.value(), externs_map); +#else + return GeneratorContext(target, autoscheduler_.value()); +#endif +#endif } void GeneratorBase::init_from_context(const Halide::GeneratorContext &context) { target.set(context.target_); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API auto_schedule.set(context.auto_schedule_); machine_params.set(context.machine_params_); +#else + autoscheduler_.set(context.autoscheduler_params_); +#endif +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE externs_map = context.externs_map_; - value_tracker = context.value_tracker_; +#endif // pre-emptively build our param_info now internal_assert(param_info_ptr == nullptr); @@ -1416,8 +1472,8 @@ void GeneratorBase::set_generator_names(const std::string ®istered_name, cons } void GeneratorBase::set_inputs_vector(const std::vector> &inputs) { + ensure_configure_has_been_called(); advance_phase(InputsSet); - internal_assert(!inputs_set) << "set_inputs_vector() must be called at most once per Generator instance.\n"; GeneratorParamInfo &pi = param_info(); user_assert(inputs.size() == pi.inputs().size()) << "Expected exactly " << pi.inputs().size() @@ -1425,36 +1481,6 @@ void GeneratorBase::set_inputs_vector(const std::vector> for (size_t i = 0; i < pi.inputs().size(); ++i) { pi.inputs()[i]->set_inputs(inputs[i]); } - inputs_set = true; -} - -void GeneratorBase::track_parameter_values(bool include_outputs) { - GeneratorParamInfo &pi = param_info(); - for (auto *input : pi.inputs()) { - if (input->kind() == IOKind::Buffer) { - internal_assert(!input->parameters_.empty()); - for (auto &p : input->parameters_) { - // This must use p.name(), *not* input->name() - value_tracker->track_values(p.name(), parameter_constraints(p)); - } - } - } - if (include_outputs) { - for (auto *output : pi.outputs()) { - if (output->kind() == IOKind::Buffer) { - internal_assert(!output->funcs().empty()); - for (const auto &f : output->funcs()) { - user_assert(f.defined()) << "Output " << output->name() << " is not fully defined."; - auto output_buffers = f.output_buffers(); - for (auto &o : output_buffers) { - Parameter p = o.parameter(); - // This must use p.name(), *not* output->name() - value_tracker->track_values(p.name(), parameter_constraints(p)); - } - } - } - } - } } void GeneratorBase::check_min_phase(Phase expected_phase) const { @@ -1474,7 +1500,7 @@ void GeneratorBase::advance_phase(Phase new_phase) { internal_assert(phase == Created); break; case InputsSet: - internal_assert(phase == Created || phase == ConfigureCalled); + internal_assert(phase == Created || phase == ConfigureCalled || phase == InputsSet); break; case GenerateCalled: // It's OK to advance directly to GenerateCalled. @@ -1507,50 +1533,23 @@ void GeneratorBase::pre_generate() { user_assert(!pi.outputs().empty()) << "Must use Output<> with generate() method."; user_assert(get_target() != Target()) << "The Generator target has not been set."; - if (!inputs_set) { - for (auto *input : pi.inputs()) { - input->init_internals(); - } - inputs_set = true; + for (auto *input : pi.inputs()) { + input->init_internals(); } for (auto *output : pi.outputs()) { output->init_internals(); } - track_parameter_values(false); } void GeneratorBase::post_generate() { - track_parameter_values(true); } void GeneratorBase::pre_schedule() { advance_phase(ScheduleCalled); - track_parameter_values(true); } void GeneratorBase::post_schedule() { - track_parameter_values(true); -} - -#ifdef HALIDE_ALLOW_GENERATOR_BUILD_METHOD -void GeneratorBase::pre_build() { - advance_phase(GenerateCalled); - advance_phase(ScheduleCalled); - GeneratorParamInfo &pi = param_info(); - user_assert(pi.outputs().empty()) << "May not use build() method with Output<>."; - if (!inputs_set) { - for (auto *input : pi.inputs()) { - input->init_internals(); - } - inputs_set = true; - } - track_parameter_values(false); -} - -void GeneratorBase::post_build() { - track_parameter_values(true); } -#endif Pipeline GeneratorBase::get_pipeline() { check_min_phase(GenerateCalled); @@ -1566,13 +1565,13 @@ Pipeline GeneratorBase::get_pipeline() { << "\" requires dimensions=" << output->dims() << " but was defined as dimensions=" << f.dimensions() << ".\n"; } - if (output->types_defined()) { - user_assert((int)f.outputs() == (int)output->types().size()) << "Output \"" << f.name() - << "\" requires a Tuple of size " << output->types().size() - << " but was defined as Tuple of size " << f.outputs() << ".\n"; - for (size_t i = 0; i < f.output_types().size(); ++i) { - Type expected = output->types().at(i); - Type actual = f.output_types()[i]; + if (output->gio_types_defined()) { + user_assert((int)f.outputs() == (int)output->gio_types().size()) << "Output \"" << f.name() + << "\" requires a Tuple of size " << output->gio_types().size() + << " but was defined as Tuple of size " << f.outputs() << ".\n"; + for (size_t i = 0; i < f.types().size(); ++i) { + Type expected = output->gio_types().at(i); + Type actual = f.types()[i]; user_assert(expected == actual) << "Output \"" << f.name() << "\" requires type " << expected << " but was defined as type " << actual << ".\n"; @@ -1586,211 +1585,150 @@ Pipeline GeneratorBase::get_pipeline() { return pipeline; } -Module GeneratorBase::build_module(const std::string &function_name, - const LinkageType linkage_type) { - AutoSchedulerResults auto_schedule_results; - ensure_configure_has_been_called(); - Pipeline pipeline = build_pipeline(); - if (get_auto_schedule()) { - auto_schedule_results = pipeline.auto_schedule(get_target(), get_machine_params()); - } - - const GeneratorParamInfo &pi = param_info(); - std::vector filter_arguments; - for (const auto *input : pi.inputs()) { - for (const auto &p : input->parameters_) { - filter_arguments.push_back(to_argument(p)); - } - } - - Module result = pipeline.compile_to_module(filter_arguments, function_name, get_target(), linkage_type); - std::shared_ptr externs_map = get_externs_map(); - for (const auto &map_entry : *externs_map) { - result.append(map_entry.second); - } - - for (const auto *output : pi.outputs()) { - for (size_t i = 0; i < output->funcs().size(); ++i) { - auto from = output->funcs()[i].name(); - auto to = output->array_name(i); - size_t tuple_size = output->types_defined() ? output->types().size() : 1; - for (size_t t = 0; t < tuple_size; ++t) { - std::string suffix = (tuple_size > 1) ? ("." + std::to_string(t)) : ""; - result.remap_metadata_name(from + suffix, to + suffix); - } - } - } - - result.set_auto_scheduler_results(auto_schedule_results); - - return result; +void GeneratorBase::check_scheduled(const char *m) const { + check_min_phase(ScheduleCalled); } -Module GeneratorBase::build_gradient_module(const std::string &function_name) { - constexpr int DBG = 1; +void GeneratorBase::check_input_is_singular(Internal::GeneratorInputBase *in) { + user_assert(!in->is_array()) + << "Input " << in->name() << " is an array, and must be set with a vector type."; +} - // I doubt these ever need customizing; if they do, we can make them arguments to this function. - const std::string grad_input_pattern = "_grad_loss_for_$OUT$"; - const std::string grad_output_pattern = "_grad_loss_$OUT$_wrt_$IN$"; - const LinkageType linkage_type = LinkageType::ExternalPlusMetadata; +void GeneratorBase::check_input_is_array(Internal::GeneratorInputBase *in) { + user_assert(in->is_array()) + << "Input " << in->name() << " is not an array, and must not be set with a vector type."; +} - user_assert(!function_name.empty()) << "build_gradient_module(): function_name cannot be empty\n"; +void GeneratorBase::check_input_kind(Internal::GeneratorInputBase *in, Internal::ArgInfoKind kind) { + user_assert(in->kind() == kind) + << "Input " << in->name() << " cannot be set with the type specified."; +} - ensure_configure_has_been_called(); - Pipeline original_pipeline = build_pipeline(); - std::vector original_outputs = original_pipeline.outputs(); - - // Construct the adjoint pipeline, which has: - // - All the same inputs as the original, in the same order - // - Followed by one grad-input for each original output - // - Followed by one output for each unique pairing of original-output + original-input. - - const GeneratorParamInfo &pi = param_info(); - - // Even though propagate_adjoints() supports Funcs-of-Tuples just fine, - // we aren't going to support them here (yet); AFAICT, neither PyTorch nor - // TF support Tensors with Tuples-as-values, so we'd have to split the - // tuples up into separate Halide inputs and outputs anyway; since Generator - // doesn't support Tuple-valued Inputs at all, and Tuple-valued Outputs - // are quite rare, we're going to just fail up front, with the assumption - // that the coder will explicitly adapt their code as needed. (Note that - // support for Tupled outputs could be added with some effort, so if this - // is somehow deemed critical, go for it) - for (const auto *input : pi.inputs()) { - const size_t tuple_size = input->types_defined() ? input->types().size() : 1; - // Note: this should never happen - internal_assert(tuple_size == 1) << "Tuple Inputs are not yet supported by build_gradient_module()"; - } - for (const auto *output : pi.outputs()) { - const size_t tuple_size = output->types_defined() ? output->types().size() : 1; - internal_assert(tuple_size == 1) << "Tuple Outputs are not yet supported by build_gradient_module"; - } - - std::vector gradient_inputs; - - // First: the original inputs. Note that scalar inputs remain scalar, - // rather being promoted into zero-dimensional buffers. - for (const auto *input : pi.inputs()) { - // There can be multiple Funcs/Parameters per input if the - // input is an Array. - if (input->is_array()) { - internal_assert(input->parameters_.size() == input->funcs_.size()); - } - for (const auto &p : input->parameters_) { - gradient_inputs.push_back(to_argument(p)); - debug(DBG) << " gradient copied input is: " << gradient_inputs.back().name << "\n"; - } +void GeneratorBase::set_generatorparam_value(const std::string &name, const std::string &value) { +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + if (name == "target" || + name == "auto_schedule" || + name == "machine_params") { + user_error + << "The GeneratorParam named " << name << " cannot be set by set_generatorparam_value().\n"; } +#else + user_assert(name != "target") << "The GeneratorParam named " << name << " cannot be set by set_generatorparam_value().\n"; + if (autoscheduler_.try_set(name, value)) { + return; + } +#endif - // Next: add a grad-input for each *original* output; these will - // be the same shape as the output (so we should copy estimates from - // those outputs onto these estimates). - // - If an output is an Array, we'll have a separate input for each array element. - - std::vector d_output_imageparams; - for (const auto *output : pi.outputs()) { - for (size_t i = 0; i < output->funcs().size(); ++i) { - const Func &f = output->funcs()[i]; - const std::string output_name = output->array_name(i); - // output_name is something like "funcname_i" - const std::string grad_in_name = replace_all(grad_input_pattern, "$OUT$", output_name); - // TODO(srj): does it make sense for gradient to be a non-float type? - // For now, assume it's always float32 (unless the output is already some float). - const Type grad_in_type = output->type().is_float() ? output->type() : Float(32); - const int grad_in_dimensions = f.dimensions(); - const ArgumentEstimates grad_in_estimates = f.output_buffer().parameter().get_argument_estimates(); - internal_assert((int)grad_in_estimates.buffer_estimates.size() == grad_in_dimensions); - - ImageParam d_im(grad_in_type, grad_in_dimensions, grad_in_name); - for (int d = 0; d < grad_in_dimensions; d++) { - d_im.parameter().set_min_constraint_estimate(d, grad_in_estimates.buffer_estimates[i].min); - d_im.parameter().set_extent_constraint_estimate(d, grad_in_estimates.buffer_estimates[i].extent); - } - d_output_imageparams.push_back(d_im); - gradient_inputs.push_back(to_argument(d_im.parameter())); + GeneratorParamInfo &pi = param_info(); - debug(DBG) << " gradient synthesized input is: " << gradient_inputs.back().name << "\n"; + for (auto *g : pi.generator_params()) { + if (g->name() != name) { + continue; } + g->set_from_string(value); + return; } + user_error + << "Generator " << generator_registered_name << " has no GeneratorParam named: " << name << "\n"; +} - // Finally: define the output Func(s), one for each unique output/input pair. - // Note that original_outputs.size() != pi.outputs().size() if any outputs are arrays. - internal_assert(original_outputs.size() == d_output_imageparams.size()); - std::vector gradient_outputs; - for (size_t i = 0; i < original_outputs.size(); ++i) { - const Func &original_output = original_outputs.at(i); - const ImageParam &d_output = d_output_imageparams.at(i); - Region bounds; - for (int i = 0; i < d_output.dimensions(); i++) { - bounds.emplace_back(d_output.dim(i).min(), d_output.dim(i).extent()); +void GeneratorBase::set_generatorparam_value(const std::string &name, const LoopLevel &value) { + GeneratorParamInfo &pi = param_info(); + for (auto *g : pi.generator_params()) { + if (g->name() != name) { + continue; } - Func adjoint_func = BoundaryConditions::constant_exterior(d_output, make_zero(d_output.type())); - Derivative d = propagate_adjoints(original_output, adjoint_func, bounds); - - const std::string &output_name = original_output.name(); - for (const auto *input : pi.inputs()) { - for (size_t i = 0; i < input->funcs_.size(); ++i) { - const std::string input_name = input->array_name(i); - const auto &f = input->funcs_[i]; - const auto &p = input->parameters_[i]; + user_assert(g->is_looplevel_param()) << "GeneratorParam " << name << " is not a LoopLevel and cannot be set this way."; + g->set(value); + return; + } + user_error + << "Generator " << generator_registered_name << " has no GeneratorParam named: " << name << "\n"; +} - Func d_f = d(f); +std::string GeneratorBase::name() { + return generator_registered_name; +} - std::string grad_out_name = replace_all(replace_all(grad_output_pattern, "$OUT$", output_name), "$IN$", input_name); - if (!d_f.defined()) { - grad_out_name = "_dummy" + grad_out_name; - } +std::vector GeneratorBase::arginfos() { + ensure_configure_has_been_called(); + std::vector args; + args.reserve(param_info().inputs().size() + param_info().outputs().size()); + GeneratorBase::get_arguments(args, ArgInfoDirection::Input, param_info().inputs()); + GeneratorBase::get_arguments(args, ArgInfoDirection::Output, param_info().outputs()); + return args; +} - Func d_out_wrt_in(grad_out_name); - if (d_f.defined()) { - d_out_wrt_in(Halide::_) = d_f(Halide::_); - } else { - debug(DBG) << " No Derivative found for output " << output_name << " wrt input " << input_name << "\n"; - // If there was no Derivative found, don't skip the output; - // just replace with a dummy Func that is all zeros. This ensures - // that the signature of the Pipeline we produce is always predictable. - std::vector vars; - for (int i = 0; i < d_output.dimensions(); i++) { - vars.push_back(Var::implicit(i)); - } - d_out_wrt_in(vars) = make_zero(d_output.type()); - } +std::vector GeneratorBase::input_parameter(const std::string &name) { + auto *input = find_input_by_name(name); - d_out_wrt_in.set_estimates(p.get_argument_estimates().buffer_estimates); + const size_t params_size = input->parameters_.size(); + const bool is_buffer = input->kind() != ArgInfoKind::Scalar; + if (is_buffer) { + internal_assert(input->exprs_.empty() && input->funcs_.size() == params_size); + } else { + internal_assert(input->funcs_.empty() && input->exprs_.size() == params_size); + } - // Useful for debugging; ordinarily better to leave out - // debug(0) << "\n\n" - // << "output:\n" << FuncWithDependencies(original_output) << "\n" - // << "d_output:\n" << FuncWithDependencies(adjoint_func) << "\n" - // << "input:\n" << FuncWithDependencies(f) << "\n" - // << "d_out_wrt_in:\n" << FuncWithDependencies(d_out_wrt_in) << "\n"; + std::vector params; + params.reserve(params_size); - gradient_outputs.push_back(d_out_wrt_in); - debug(DBG) << " gradient output is: " << d_out_wrt_in.name() << "\n"; - } - } + for (size_t i = 0; i < params_size; ++i) { + const auto &p = input->parameters_[i]; + internal_assert(p.is_buffer() == is_buffer); + const auto name = input->array_name(i); + internal_assert(p.name() == name) << "input name was " << p.name() << " expected " << name; + const int expected_dimensions = is_buffer ? input->funcs_[i].dimensions() : 0; + internal_assert(p.dimensions() == expected_dimensions) << "input dimensions was " << p.dimensions() << " expected " << expected_dimensions; + internal_assert(p.type() == input->gio_type()) << "input type was " << p.type() << " expected " << input->gio_type(); + params.push_back(p); } + return params; +} - Pipeline grad_pipeline = Pipeline(gradient_outputs); - - AutoSchedulerResults auto_schedule_results; - if (get_auto_schedule()) { - auto_schedule_results = grad_pipeline.auto_schedule(get_target(), get_machine_params()); - } else { - user_warning << "Autoscheduling is not enabled in build_gradient_module(), so the resulting " - "gradient module will be unscheduled; this is very unlikely to be what you want.\n"; +std::vector GeneratorBase::output_func(const std::string &n) { + check_min_phase(GenerateCalled); + auto *output = find_output_by_name(n); + // Call for the side-effect of asserting if the value isn't defined. + (void)output->array_size(); + for (const auto &f : output->funcs()) { + user_assert(f.defined()) << "Output " << n << " was not fully defined.\n"; } + return output->funcs(); +} + +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE +ExternsMap GeneratorBase::external_code_map() { + // get_externs_map() returns a std::shared_ptr + return *get_externs_map(); +} +#endif - Module result = grad_pipeline.compile_to_module(gradient_inputs, function_name, get_target(), linkage_type); - user_assert(get_externs_map()->empty()) - << "Building a gradient-descent module for a Generator with ExternalCode is not supported.\n"; +void GeneratorBase::bind_input(const std::string &name, const std::vector &v) { + ensure_configure_has_been_called(); + advance_phase(InputsSet); + std::vector si; + std::copy(v.begin(), v.end(), std::back_inserter(si)); + find_input_by_name(name)->set_inputs(si); +} - result.set_auto_scheduler_results(auto_schedule_results); +void GeneratorBase::bind_input(const std::string &name, const std::vector &v) { + ensure_configure_has_been_called(); + advance_phase(InputsSet); + std::vector si; + std::copy(v.begin(), v.end(), std::back_inserter(si)); + find_input_by_name(name)->set_inputs(si); +} - return result; +void GeneratorBase::bind_input(const std::string &name, const std::vector &v) { + ensure_configure_has_been_called(); + advance_phase(InputsSet); + std::vector si; + std::copy(v.begin(), v.end(), std::back_inserter(si)); + find_input_by_name(name)->set_inputs(si); } -void GeneratorBase::emit_cpp_stub(const std::string &stub_file_path) { +bool GeneratorBase::emit_cpp_stub(const std::string &stub_file_path) { user_assert(!generator_registered_name.empty() && !generator_stub_name.empty()) << "Generator has no name.\n"; // Make sure we call configure() so that extra inputs/outputs are added as necessary. ensure_configure_has_been_called(); @@ -1801,30 +1739,12 @@ void GeneratorBase::emit_cpp_stub(const std::string &stub_file_path) { std::ofstream file(stub_file_path); StubEmitter emit(file, generator_registered_name, generator_stub_name, pi.generator_params(), pi.inputs(), pi.outputs()); emit.emit(); -} - -void GeneratorBase::check_scheduled(const char *m) const { - check_min_phase(ScheduleCalled); -} - -void GeneratorBase::check_input_is_singular(Internal::GeneratorInputBase *in) { - user_assert(!in->is_array()) - << "Input " << in->name() << " is an array, and must be set with a vector type."; -} - -void GeneratorBase::check_input_is_array(Internal::GeneratorInputBase *in) { - user_assert(in->is_array()) - << "Input " << in->name() << " is not an array, and must not be set with a vector type."; -} - -void GeneratorBase::check_input_kind(Internal::GeneratorInputBase *in, Internal::IOKind kind) { - user_assert(in->kind() == kind) - << "Input " << in->name() << " cannot be set with the type specified."; + return true; } GIOBase::GIOBase(size_t array_size, const std::string &name, - IOKind kind, + ArgInfoKind kind, const std::vector &types, int dims) : array_size_(array_size), name_(name), kind_(kind), types_(types), dims_(dims) { @@ -1849,38 +1769,38 @@ const std::string &GIOBase::name() const { return name_; } -IOKind GIOBase::kind() const { +ArgInfoKind GIOBase::kind() const { return kind_; } -bool GIOBase::types_defined() const { +bool GIOBase::gio_types_defined() const { return !types_.empty(); } -const std::vector &GIOBase::types() const { +const std::vector &GIOBase::gio_types() const { // If types aren't defined, but we have one Func that is, // we probably just set an Output and should propagate the types. - if (!types_defined()) { + if (!gio_types_defined()) { // use funcs_, not funcs(): the latter could give a much-less-helpful error message // in this case. const auto &f = funcs_; if (f.size() == 1 && f.at(0).defined()) { - check_matching_types(f.at(0).output_types()); + check_matching_types(f.at(0).types()); } } - user_assert(types_defined()) << "Type is not defined for " << input_or_output() << " '" << name() << "'; you may need to specify '" << name() << ".type' as a GeneratorParam, or call set_type() from the configure() method.\n"; + user_assert(gio_types_defined()) << "Type is not defined for " << input_or_output() << " '" << name() << "'; you may need to specify '" << name() << ".type' as a GeneratorParam, or call set_type() from the configure() method.\n"; return types_; } -Type GIOBase::type() const { - const auto &t = types(); +Type GIOBase::gio_type() const { + const auto &t = gio_types(); internal_assert(t.size() == 1) << "Expected types_.size() == 1, saw " << t.size() << " for " << name() << "\n"; return t.at(0); } void GIOBase::set_type(const Type &type) { generator->check_exact_phase(GeneratorBase::ConfigureCalled); - user_assert(!types_defined()) << "set_type() may only be called on an Input or Output that has no type specified."; + user_assert(!gio_types_defined()) << "set_type() may only be called on an Input or Output that has no type specified."; types_ = {type}; } @@ -1928,7 +1848,7 @@ const std::vector &GIOBase::exprs() const { void GIOBase::verify_internals() { user_assert(dims_ >= 0) << "Generator Input/Output Dimensions must have positive values"; - if (kind() != IOKind::Scalar) { + if (kind() != ArgInfoKind::Scalar) { for (const Func &f : funcs()) { user_assert(f.defined()) << "Input/Output " << name() << " is not defined.\n"; user_assert(f.dimensions() == dims()) @@ -1939,20 +1859,20 @@ void GIOBase::verify_internals() { << "Expected outputs() == " << 1 << " but got " << f.outputs() << " for " << name() << "\n"; - user_assert(f.output_types().size() == 1) - << "Expected output_types().size() == " << 1 + user_assert(f.types().size() == 1) + << "Expected types().size() == " << 1 << " but got " << f.outputs() << " for " << name() << "\n"; - user_assert(f.output_types()[0] == type()) - << "Expected type " << type() - << " but got " << f.output_types()[0] + user_assert(f.types()[0] == gio_type()) + << "Expected type " << gio_type() + << " but got " << f.types()[0] << " for " << name() << "\n"; } } else { for (const Expr &e : exprs()) { user_assert(e.defined()) << "Input/Ouput " << name() << " is not defined.\n"; - user_assert(e.type() == type()) - << "Expected type " << type() + user_assert(e.type() == gio_type()) + << "Expected type " << gio_type() << " but got " << e.type() << " for " << name() << "\n"; } @@ -1970,10 +1890,10 @@ std::string GIOBase::array_name(size_t i) const { // If our type(s) are defined, ensure it matches the ones passed in, asserting if not. // If our type(s) are not defined, just set to the ones passed in. void GIOBase::check_matching_types(const std::vector &t) const { - if (types_defined()) { - user_assert(types().size() == t.size()) << "Type mismatch for " << name() << ": expected " << types().size() << " types but saw " << t.size(); + if (gio_types_defined()) { + user_assert(gio_types().size() == t.size()) << "Type mismatch for " << name() << ": expected " << gio_types().size() << " types but saw " << t.size(); for (size_t i = 0; i < t.size(); ++i) { - user_assert(types().at(i) == t.at(i)) << "Type mismatch for " << name() << ": expected " << types().at(i) << " saw " << t.at(i); + user_assert(gio_types().at(i) == t.at(i)) << "Type mismatch for " << name() << ": expected " << gio_types().at(i) << " saw " << t.at(i); } } else { types_ = t; @@ -1985,13 +1905,8 @@ void GIOBase::check_gio_access() const { if (!generator) { return; } -#ifdef HALIDE_ALLOW_GENERATOR_BUILD_METHOD - user_assert(generator->phase > GeneratorBase::InputsSet) - << "The " << input_or_output() << " \"" << name() << "\" cannot be examined before build() or generate() is called.\n"; -#else user_assert(generator->phase > GeneratorBase::InputsSet) << "The " << input_or_output() << " \"" << name() << "\" cannot be examined before generate() is called.\n"; -#endif } // If our dims are defined, ensure it matches the one passed in, asserting if not. @@ -2015,14 +1930,14 @@ void GIOBase::check_matching_array_size(size_t size) const { GeneratorInputBase::GeneratorInputBase(size_t array_size, const std::string &name, - IOKind kind, + ArgInfoKind kind, const std::vector &t, int d) : GIOBase(array_size, name, kind, t, d) { ObjectInstanceRegistry::register_instance(this, 0, ObjectInstanceRegistry::GeneratorInput, this, nullptr); } -GeneratorInputBase::GeneratorInputBase(const std::string &name, IOKind kind, const std::vector &t, int d) +GeneratorInputBase::GeneratorInputBase(const std::string &name, ArgInfoKind kind, const std::vector &t, int d) : GeneratorInputBase(1, name, kind, t, d) { // nothing } @@ -2048,15 +1963,19 @@ Parameter GeneratorInputBase::parameter() const { void GeneratorInputBase::verify_internals() { GIOBase::verify_internals(); - const size_t expected = (kind() != IOKind::Scalar) ? funcs().size() : exprs().size(); + const size_t expected = (kind() != ArgInfoKind::Scalar) ? funcs().size() : exprs().size(); user_assert(parameters_.size() == expected) << "Expected parameters_.size() == " << expected << ", saw " << parameters_.size() << " for " << name() << "\n"; } void GeneratorInputBase::init_internals() { + if (inputs_set) { + return; + } + // Call these for the side-effect of asserting if the values aren't defined. (void)array_size(); - (void)types(); + (void)gio_types(); (void)dims(); parameters_.clear(); @@ -2064,13 +1983,13 @@ void GeneratorInputBase::init_internals() { funcs_.clear(); for (size_t i = 0; i < array_size(); ++i) { auto name = array_name(i); - parameters_.emplace_back(type(), kind() != IOKind::Scalar, dims(), name); + parameters_.emplace_back(gio_type(), kind() != ArgInfoKind::Scalar, dims(), name); auto &p = parameters_[i]; - if (kind() != IOKind::Scalar) { + if (kind() != ArgInfoKind::Scalar) { internal_assert(dims() == p.dimensions()); funcs_.push_back(make_param_func(p, name)); } else { - Expr e = Internal::Variable::make(type(), name, p); + Expr e = Internal::Variable::make(gio_type(), name, p); exprs_.push_back(e); } } @@ -2088,14 +2007,14 @@ void GeneratorInputBase::set_inputs(const std::vector &inputs) { for (size_t i = 0; i < inputs.size(); ++i) { const StubInput &in = inputs.at(i); user_assert(in.kind() == kind()) << "An input for " << name() << " is not of the expected kind.\n"; - if (kind() == IOKind::Function) { + if (kind() == ArgInfoKind::Function) { auto f = in.func(); user_assert(f.defined()) << "The input for " << name() << " is an undefined Func. Please define it.\n"; - check_matching_types(f.output_types()); + check_matching_types(f.types()); check_matching_dims(f.dimensions()); funcs_.push_back(f); - parameters_.emplace_back(f.output_types().at(0), true, f.dimensions(), array_name(i)); - } else if (kind() == IOKind::Buffer) { + parameters_.emplace_back(f.types().at(0), true, f.dimensions(), array_name(i)); + } else if (kind() == ArgInfoKind::Buffer) { auto p = in.parameter(); user_assert(p.defined()) << "The input for " << name() << " is an undefined Buffer. Please define it.\n"; check_matching_types({p.type()}); @@ -2114,6 +2033,7 @@ void GeneratorInputBase::set_inputs(const std::vector &inputs) { set_def_min_max(); verify_internals(); + inputs_set = true; } void GeneratorInputBase::set_estimate_impl(const Var &var, const Expr &min, const Expr &extent) { @@ -2154,14 +2074,14 @@ void GeneratorInputBase::set_estimates_impl(const Region &estimates) { } } -GeneratorOutputBase::GeneratorOutputBase(size_t array_size, const std::string &name, IOKind kind, const std::vector &t, int d) +GeneratorOutputBase::GeneratorOutputBase(size_t array_size, const std::string &name, ArgInfoKind kind, const std::vector &t, int d) : GIOBase(array_size, name, kind, t, d) { - internal_assert(kind != IOKind::Scalar); + internal_assert(kind != ArgInfoKind::Scalar); ObjectInstanceRegistry::register_instance(this, 0, ObjectInstanceRegistry::GeneratorOutput, this, nullptr); } -GeneratorOutputBase::GeneratorOutputBase(const std::string &name, IOKind kind, const std::vector &t, int d) +GeneratorOutputBase::GeneratorOutputBase(const std::string &name, ArgInfoKind kind, const std::vector &t, int d) : GeneratorOutputBase(1, name, kind, t, d) { // nothing } @@ -2179,8 +2099,10 @@ void GeneratorOutputBase::init_internals() { exprs_.clear(); funcs_.clear(); if (array_size_defined()) { + const auto t = gio_types_defined() ? gio_types() : std::vector{}; + const int d = dims_defined() ? dims() : -1; for (size_t i = 0; i < array_size(); ++i) { - funcs_.emplace_back(array_name(i)); + funcs_.emplace_back(t, d, array_name(i)); } } } @@ -2195,20 +2117,16 @@ void GeneratorOutputBase::resize(size_t size) { StubOutputBufferBase::StubOutputBufferBase() = default; -StubOutputBufferBase::StubOutputBufferBase(const Func &f, const std::shared_ptr &generator) +StubOutputBufferBase::StubOutputBufferBase(const Func &f, const std::shared_ptr &generator) : f(f), generator(generator) { } -void StubOutputBufferBase::check_scheduled(const char *m) const { - generator->check_scheduled(m); -} - Realization StubOutputBufferBase::realize(std::vector sizes) { return f.realize(std::move(sizes), get_target()); } Target StubOutputBufferBase::get_target() const { - return generator->get_target(); + return generator->context().target(); } RegisterGenerator::RegisterGenerator(const char *registered_name, GeneratorFactory generator_factory) { @@ -2216,7 +2134,7 @@ RegisterGenerator::RegisterGenerator(const char *registered_name, GeneratorFacto } void generator_test() { - GeneratorContext context(get_host_target()); + GeneratorContext context(get_host_target().without_feature(Target::Profile)); // Verify that the Generator's internal phase actually prevents unsupported // order of operations. @@ -2271,75 +2189,6 @@ void generator_test() { // tester.sp2.set(202); // This will assert-fail. } -#ifdef HALIDE_ALLOW_GENERATOR_BUILD_METHOD - // Verify that the Generator's internal phase actually prevents unsupported - // order of operations (with old-style Generator) - { - class Tester : public Generator { - public: - GeneratorParam gp0{"gp0", 0}; - GeneratorParam gp1{"gp1", 1.f}; - GeneratorParam gp2{"gp2", 2}; - GeneratorParam gp_uint8{"gp_uint8", 65}; - GeneratorParam gp_int8{"gp_int8", 66}; - GeneratorParam gp_char{"gp_char", 97}; - GeneratorParam gp_schar{"gp_schar", 98}; - GeneratorParam gp_uchar{"gp_uchar", 99}; - GeneratorParam gp_bool{"gp_bool", true}; - - Input input{"input"}; - - Func build() { - internal_assert(gp0 == 1); - internal_assert(gp1 == 2.f); - internal_assert(gp2 == (uint64_t)2); // unchanged - internal_assert(gp_uint8 == 67); - internal_assert(gp_int8 == 68); - internal_assert(gp_bool == false); - internal_assert(gp_char == 107); - internal_assert(gp_schar == 108); - internal_assert(gp_uchar == 109); - Var x; - Func output; - output(x) = input + gp0; - return output; - } - }; - - Tester tester; - tester.init_from_context(context); - internal_assert(tester.phase == GeneratorBase::Created); - - // Verify that calling GeneratorParam::set() works. - tester.gp0.set(1); - - // set_inputs_vector() can't be called on an old-style Generator; - // that's OK, since we can skip from Created -> GenerateCalled anyway - // tester.set_inputs_vector({{StubInput(42)}}); - // internal_assert(tester.phase == GeneratorBase::InputsSet); - - // tester.set_inputs_vector({{StubInput(43)}}); // This will assert-fail. - - // Also ok to call in this phase. - tester.gp1.set(2.f); - - // Verify that 8-bit non-boolean GP values are parsed as integers, not chars. - tester.gp_int8.set_from_string("68"); - tester.gp_uint8.set_from_string("67"); - tester.gp_char.set_from_string("107"); - tester.gp_schar.set_from_string("108"); - tester.gp_uchar.set_from_string("109"); - tester.gp_bool.set_from_string("false"); - - tester.build_pipeline(); - internal_assert(tester.phase == GeneratorBase::ScheduleCalled); - - // tester.set_inputs_vector({{StubInput(45)}}); // This will assert-fail. - // tester.gp2.set(2); // This will assert-fail. - // tester.sp2.set(202); // This will assert-fail. - } -#endif - // Verify that set_inputs() works properly, even if the specific subtype of Generator is not known. { class Tester : public Generator { @@ -2437,11 +2286,22 @@ void generator_test() { public: GeneratorParam gp{"gp", 0}; Output output{"output", Int(32), 0}; + void generate() { + internal_assert(get_target().has_feature(Target::Profile)); output() = 0; } void schedule() { } + + // Test that we can override init_from_context() to modify the target + // we use. (Generally speaking, your code probably should ever need to + // do this; this code only does it for testing purposes. See comments + // in Generator.h.) + void init_from_context(const GeneratorContext &context) override { + auto t = context.target().with_feature(Target::Profile); + Generator::init_from_context(context.with_target(t)); + } }; GPTester gp_tester; gp_tester.init_from_context(context); diff --git a/src/Generator.h b/src/Generator.h index 98dc0940a7c5..d40fddc79141 100644 --- a/src/Generator.h +++ b/src/Generator.h @@ -208,27 +208,27 @@ * }; * \endcode * - * All Generators have three GeneratorParams that are implicitly provided + * All Generators have two GeneratorParams that are implicitly provided * by the base class: * * GeneratorParam target{"target", Target()}; - * GeneratorParam auto_schedule{"auto_schedule", false}; - * GeneratorParam machine_params{"machine_params", MachineParams::generic()}; + * GeneratorParam autoscheduler{"autoscheduler", {}} * * - 'target' is the Halide::Target for which the Generator is producing code. * It is read-only during the Generator's lifetime, and must not be modified; * its value should always be filled in by the calling code: either the Halide * build system (for ahead-of-time compilation), or ordinary C++ code * (for JIT compilation). - * - 'auto_schedule' indicates whether the auto-scheduler should be run for this - * Generator: - * - if 'false', the Generator should schedule its Funcs as it sees fit. - * - if 'true', the Generator should only provide estimate()s for its Funcs, - * and not call any other scheduling methods. - * - 'machine_params' is only used if auto_schedule is true; it is ignored - * if auto_schedule is false. It provides details about the machine architecture - * being targeted which may be used to enhance the automatically-generated - * schedule. + * - 'autoscheduler' is a string-to-string map that is used to indicates whether + * and how an auto-scheduler should be run for this Generator: + * - if empty, the Generator should schedule its Funcs as it sees fit; no autoscheduler will be run. + * - if the 'name' key is set, it should be one of the known autoschedulers + * provided with this release of Halide, which will be used to schedule + * the Funcs in the Generator. In this case, the Generator should only + * provide estimate()s for its Funcs, and not call any other scheduling methods. + * - Other keys may be specified in the params, on a per-autoscheduler + * basis, to optimize or enhance the automatically-generated schedule. + * See documentation for each autoscheduler for options. * * Generators are added to a global registry to simplify AOT build mechanics; this * is done by simply using the HALIDE_REGISTER_GENERATOR macro at global scope: @@ -271,7 +271,10 @@ #include #include +#include "AbstractGenerator.h" +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE #include "ExternalCode.h" +#endif #include "Func.h" #include "ImageParam.h" #include "Introspection.h" @@ -283,11 +286,14 @@ #endif namespace Halide { + +class GeneratorContext; + namespace Internal { void generator_test(); -class ValueTracker; +class GeneratorBase; std::vector parameter_constraints(const Parameter &p); @@ -322,10 +328,43 @@ std::string halide_type_to_c_source(const Type &t); // e.g., Int(32) -> "int32_t" std::string halide_type_to_c_type(const Type &t); +/** GeneratorFactoryProvider provides a way to customize the Generators + * that are visible to generate_filter_main (which otherwise would just + * look at the global registry of C++ Generators). */ +class GeneratorFactoryProvider { +public: + GeneratorFactoryProvider() = default; + virtual ~GeneratorFactoryProvider() = default; + + /** Return a list of all registered Generators that are available for use + * with the create() method. */ + virtual std::vector enumerate() const = 0; + + /** Create an instance of the Generator that is registered under the given + * name. If the name isn't one returned by enumerate(), return nullptr + * rather than assert-fail; caller must check for a valid result. */ + virtual AbstractGeneratorPtr create(const std::string &name, + const Halide::GeneratorContext &context) const = 0; + + GeneratorFactoryProvider(const GeneratorFactoryProvider &) = delete; + GeneratorFactoryProvider &operator=(const GeneratorFactoryProvider &) = delete; + GeneratorFactoryProvider(GeneratorFactoryProvider &&) = delete; + GeneratorFactoryProvider &operator=(GeneratorFactoryProvider &&) = delete; +}; + +/** Return a GeneratorFactoryProvider that knows about all the currently-registered C++ Generators. */ +const GeneratorFactoryProvider &get_registered_generators(); + /** generate_filter_main() is a convenient wrapper for GeneratorRegistry::create() + * compile_to_files(); it can be trivially wrapped by a "real" main() to produce a * command-line utility for ahead-of-time filter compilation. */ -int generate_filter_main(int argc, char **argv, std::ostream &cerr); +int generate_filter_main(int argc, char **argv); + +/** This overload of generate_filter_main lets you provide your own provider for how to enumerate and/or create + * the generators based on registration name; this is useful if you want to re-use the + * 'main' logic but avoid the global Generator registry (e.g. for bindings in languages + * other than C++). */ +int generate_filter_main(int argc, char **argv, const GeneratorFactoryProvider &generator_factory_provider); // select_type<> is to std::conditional as switch is to if: // it allows a multiway compile-time type definition via the form @@ -353,7 +392,6 @@ struct select_type : std::conditional struct select_type { using type = typename std::conditional::type; }; -class GeneratorBase; class GeneratorParamInfo; class GeneratorParamBase { @@ -388,7 +426,11 @@ class GeneratorParamBase { HALIDE_GENERATOR_PARAM_TYPED_SETTER(float) HALIDE_GENERATOR_PARAM_TYPED_SETTER(double) HALIDE_GENERATOR_PARAM_TYPED_SETTER(Target) +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API HALIDE_GENERATOR_PARAM_TYPED_SETTER(MachineParams) +#else + HALIDE_GENERATOR_PARAM_TYPED_SETTER(AutoschedulerParams) +#endif HALIDE_GENERATOR_PARAM_TYPED_SETTER(Type) HALIDE_GENERATOR_PARAM_TYPED_SETTER(LoopLevel) @@ -502,7 +544,11 @@ class GeneratorParamImpl : public GeneratorParamBase { HALIDE_GENERATOR_PARAM_TYPED_SETTER(float) HALIDE_GENERATOR_PARAM_TYPED_SETTER(double) HALIDE_GENERATOR_PARAM_TYPED_SETTER(Target) +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API HALIDE_GENERATOR_PARAM_TYPED_SETTER(MachineParams) +#else + HALIDE_GENERATOR_PARAM_TYPED_SETTER(AutoschedulerParams) +#endif HALIDE_GENERATOR_PARAM_TYPED_SETTER(Type) HALIDE_GENERATOR_PARAM_TYPED_SETTER(LoopLevel) @@ -596,6 +642,7 @@ class GeneratorParam_Target : public GeneratorParamImpl { } }; +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API template class GeneratorParam_MachineParams : public GeneratorParamImpl { public: @@ -621,6 +668,22 @@ class GeneratorParam_MachineParams : public GeneratorParamImpl { return "MachineParams"; } }; +#else +class GeneratorParam_AutoSchedulerParams : public GeneratorParamImpl { +public: + GeneratorParam_AutoSchedulerParams(); + + void set_from_string(const std::string &new_value_string) override; + std::string get_default_value() const override; + std::string call_to_string(const std::string &v) const override; + std::string get_c_type() const override; + +private: + friend class GeneratorBase; + + bool try_set(const std::string &key, const std::string &value); +}; +#endif class GeneratorParam_LoopLevel : public GeneratorParamImpl { public: @@ -916,7 +979,9 @@ template using GeneratorParamImplBase = typename select_type< cond::value, GeneratorParam_Target>, +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API cond::value, GeneratorParam_MachineParams>, +#endif cond::value, GeneratorParam_LoopLevel>, cond::value, GeneratorParam_String>, cond::value, GeneratorParam_Type>, @@ -1223,10 +1288,6 @@ namespace Internal { template class GeneratorInput_Buffer; -enum class IOKind { Scalar, - Function, - Buffer }; - /** * StubInputBuffer is the placeholder that a Stub uses when it requires * a Buffer for an input (rather than merely a Func or Expr). It is constructed @@ -1241,6 +1302,8 @@ class StubInputBuffer { friend class StubInput; template friend class GeneratorInput_Buffer; + template + friend class StubInputBuffer; Parameter parameter_; @@ -1273,31 +1336,45 @@ class StubInputBuffer { StubInputBuffer(const Buffer &b) : parameter_(parameter_from_buffer(b)) { } + + template + static std::vector to_parameter_vector(const StubInputBuffer &t) { + return {t.parameter_}; + } + + template + static std::vector to_parameter_vector(const std::vector> &v) { + std::vector r; + r.reserve(v.size()); + for (const auto &s : v) { + r.push_back(s.parameter_); + } + return r; + } }; +class AbstractGenerator; + class StubOutputBufferBase { protected: Func f; - std::shared_ptr generator; + std::shared_ptr generator; - void check_scheduled(const char *m) const; Target get_target() const; StubOutputBufferBase(); - explicit StubOutputBufferBase(const Func &f, const std::shared_ptr &generator); + explicit StubOutputBufferBase(const Func &f, const std::shared_ptr &generator); public: Realization realize(std::vector sizes); template Realization realize(Args &&...args) { - check_scheduled("realize"); return f.realize(std::forward(args)..., get_target()); } template void realize(Dst dst) { - check_scheduled("realize"); f.realize(dst, get_target()); } }; @@ -1318,13 +1395,21 @@ template class StubOutputBuffer : public StubOutputBufferBase { template friend class GeneratorOutput_Buffer; - friend class GeneratorStub; - explicit StubOutputBuffer(const Func &f, const std::shared_ptr &generator) - : StubOutputBufferBase(f, generator) { + explicit StubOutputBuffer(const Func &fn, const std::shared_ptr &gen) + : StubOutputBufferBase(fn, gen) { } public: StubOutputBuffer() = default; + + static std::vector> to_output_buffers(const std::vector &v, + const std::shared_ptr &gen) { + std::vector> result; + for (const Func &f : v) { + result.push_back(StubOutputBuffer(f, gen)); + } + return result; + } }; // This is a union-like class that allows for convenient initialization of Stub Inputs @@ -1332,7 +1417,7 @@ class StubOutputBuffer : public StubOutputBufferBase { // downstream consumer will be able to explicitly check that each value is // of the expected/required kind. class StubInput { - const IOKind kind_; + const ArgInfoKind kind_; // Exactly one of the following fields should be defined: const Parameter parameter_; const Func func_; @@ -1342,34 +1427,34 @@ class StubInput { // *not* explicit. template StubInput(const StubInputBuffer &b) - : kind_(IOKind::Buffer), parameter_(b.parameter_), func_(), expr_() { + : kind_(ArgInfoKind::Buffer), parameter_(b.parameter_), func_(), expr_() { + } + StubInput(const Parameter &p) + : kind_(ArgInfoKind::Buffer), parameter_(p), func_(), expr_() { } StubInput(const Func &f) - : kind_(IOKind::Function), parameter_(), func_(f), expr_() { + : kind_(ArgInfoKind::Function), parameter_(), func_(f), expr_() { } StubInput(const Expr &e) - : kind_(IOKind::Scalar), parameter_(), func_(), expr_(e) { + : kind_(ArgInfoKind::Scalar), parameter_(), func_(), expr_(e) { } -private: - friend class GeneratorInputBase; - - IOKind kind() const { + ArgInfoKind kind() const { return kind_; } Parameter parameter() const { - internal_assert(kind_ == IOKind::Buffer); + internal_assert(kind_ == ArgInfoKind::Buffer); return parameter_; } Func func() const { - internal_assert(kind_ == IOKind::Function); + internal_assert(kind_ == ArgInfoKind::Function); return func_; } Expr expr() const { - internal_assert(kind_ == IOKind::Scalar); + internal_assert(kind_ == ArgInfoKind::Scalar); return expr_; } }; @@ -1394,16 +1479,25 @@ class StubInput { */ class GIOBase { public: + virtual ~GIOBase() = default; + + // These should only be called from configure() methods. + // TODO: find a way to enforce this. Better yet, find a way to remove these. + void set_type(const Type &type); + void set_dimensions(int dims); + void set_array_size(int size); + +protected: bool array_size_defined() const; size_t array_size() const; virtual bool is_array() const; const std::string &name() const; - IOKind kind() const; + ArgInfoKind kind() const; - bool types_defined() const; - const std::vector &types() const; - Type type() const; + bool gio_types_defined() const; + const std::vector &gio_types() const; + Type gio_type() const; bool dims_defined() const; int dims() const; @@ -1411,16 +1505,9 @@ class GIOBase { const std::vector &funcs() const; const std::vector &exprs() const; - virtual ~GIOBase() = default; - - void set_type(const Type &type); - void set_dimensions(int dims); - void set_array_size(int size); - -protected: GIOBase(size_t array_size, const std::string &name, - IOKind kind, + ArgInfoKind kind, const std::vector &types, int dims); @@ -1431,7 +1518,7 @@ class GIOBase { // -1 if is_array() == true but unspecified. const std::string name_; - const IOKind kind_; + const ArgInfoKind kind_; mutable std::vector types_; // empty if type is unspecified mutable int dims_; // -1 if dim is unspecified @@ -1466,6 +1553,7 @@ class GIOBase { private: template friend class GeneratorParam_Synthetic; + friend class GeneratorStub; public: GIOBase(const GIOBase &) = delete; @@ -1488,11 +1576,11 @@ class GeneratorInputBase : public GIOBase { protected: GeneratorInputBase(size_t array_size, const std::string &name, - IOKind kind, + ArgInfoKind kind, const std::vector &t, int d); - GeneratorInputBase(const std::string &name, IOKind kind, const std::vector &t, int d); + GeneratorInputBase(const std::string &name, ArgInfoKind kind, const std::vector &t, int d); friend class GeneratorBase; friend class GeneratorParamInfo; @@ -1503,6 +1591,7 @@ class GeneratorInputBase : public GIOBase { void init_internals(); void set_inputs(const std::vector &inputs); + bool inputs_set = false; virtual void set_def_min_max(); @@ -1537,21 +1626,21 @@ class GeneratorInputImpl : public GeneratorInputBase { template::value>::type * = nullptr> - GeneratorInputImpl(const std::string &name, IOKind kind, const std::vector &t, int d) + GeneratorInputImpl(const std::string &name, ArgInfoKind kind, const std::vector &t, int d) : GeneratorInputBase(name, kind, t, d) { } template::value && std::rank::value == 1 && (std::extent::value > 0)>::type * = nullptr> - GeneratorInputImpl(const std::string &name, IOKind kind, const std::vector &t, int d) + GeneratorInputImpl(const std::string &name, ArgInfoKind kind, const std::vector &t, int d) : GeneratorInputBase(std::extent::value, name, kind, t, d) { } template::value && std::rank::value == 1 && std::extent::value == 0>::type * = nullptr> - GeneratorInputImpl(const std::string &name, IOKind kind, const std::vector &t, int d) + GeneratorInputImpl(const std::string &name, ArgInfoKind kind, const std::vector &t, int d) : GeneratorInputBase(-1, name, kind, t, d) { } @@ -1637,24 +1726,24 @@ class GeneratorInput_Buffer : public GeneratorInputImpl { public: explicit GeneratorInput_Buffer(const std::string &name) - : Super(name, IOKind::Buffer, + : Super(name, ArgInfoKind::Buffer, TBase::has_static_halide_type ? std::vector{TBase::static_halide_type()} : std::vector{}, TBase::has_static_dimensions ? TBase::static_dimensions() : -1) { } GeneratorInput_Buffer(const std::string &name, const Type &t, int d) - : Super(name, IOKind::Buffer, {t}, d) { + : Super(name, ArgInfoKind::Buffer, {t}, d) { static_assert(!TBase::has_static_halide_type, "You can only specify a Type argument for Input> if T is void or omitted."); static_assert(!TBase::has_static_dimensions, "You can only specify a dimension argument for Input> if D is -1 or omitted."); } GeneratorInput_Buffer(const std::string &name, const Type &t) - : Super(name, IOKind::Buffer, {t}, -1) { + : Super(name, ArgInfoKind::Buffer, {t}, -1) { static_assert(!TBase::has_static_halide_type, "You can only specify a Type argument for Input> if T is void or omitted."); } GeneratorInput_Buffer(const std::string &name, int d) - : Super(name, IOKind::Buffer, + : Super(name, ArgInfoKind::Buffer, TBase::has_static_halide_type ? std::vector{TBase::static_halide_type()} : std::vector{}, d) { static_assert(!TBase::has_static_dimensions, "You can only specify a dimension argument for Input> if D is -1 or omitted."); @@ -1767,6 +1856,7 @@ class GeneratorInput_Buffer : public GeneratorInputImpl { HALIDE_FORWARD_METHOD_CONST(ImageParam, channels) HALIDE_FORWARD_METHOD_CONST(ImageParam, trace_loads) HALIDE_FORWARD_METHOD_CONST(ImageParam, add_trace_tag) + HALIDE_FORWARD_METHOD_CONST(ImageParam, type) // }@ }; @@ -1789,41 +1879,41 @@ class GeneratorInput_Func : public GeneratorInputImpl { public: GeneratorInput_Func(const std::string &name, const Type &t, int d) - : Super(name, IOKind::Function, {t}, d) { + : Super(name, ArgInfoKind::Function, {t}, d) { } // unspecified type GeneratorInput_Func(const std::string &name, int d) - : Super(name, IOKind::Function, {}, d) { + : Super(name, ArgInfoKind::Function, {}, d) { } // unspecified dimension GeneratorInput_Func(const std::string &name, const Type &t) - : Super(name, IOKind::Function, {t}, -1) { + : Super(name, ArgInfoKind::Function, {t}, -1) { } // unspecified type & dimension explicit GeneratorInput_Func(const std::string &name) - : Super(name, IOKind::Function, {}, -1) { + : Super(name, ArgInfoKind::Function, {}, -1) { } GeneratorInput_Func(size_t array_size, const std::string &name, const Type &t, int d) - : Super(array_size, name, IOKind::Function, {t}, d) { + : Super(array_size, name, ArgInfoKind::Function, {t}, d) { } // unspecified type GeneratorInput_Func(size_t array_size, const std::string &name, int d) - : Super(array_size, name, IOKind::Function, {}, d) { + : Super(array_size, name, ArgInfoKind::Function, {}, d) { } // unspecified dimension GeneratorInput_Func(size_t array_size, const std::string &name, const Type &t) - : Super(array_size, name, IOKind::Function, {t}, -1) { + : Super(array_size, name, ArgInfoKind::Function, {t}, -1) { } // unspecified type & dimension GeneratorInput_Func(size_t array_size, const std::string &name) - : Super(array_size, name, IOKind::Function, {}, -1) { + : Super(array_size, name, ArgInfoKind::Function, {}, -1) { } template @@ -1879,11 +1969,23 @@ class GeneratorInput_Func : public GeneratorInputImpl { // @{ HALIDE_FORWARD_METHOD_CONST(Func, args) HALIDE_FORWARD_METHOD_CONST(Func, defined) + HALIDE_FORWARD_METHOD_CONST(Func, dimensions) HALIDE_FORWARD_METHOD_CONST(Func, has_update_definition) HALIDE_FORWARD_METHOD_CONST(Func, num_update_definitions) - HALIDE_FORWARD_METHOD_CONST(Func, output_types) + HALIDE_ATTRIBUTE_DEPRECATED("Func::output_type() is deprecated; use Func::type() instead.") + const Type &output_type() const { + this->check_gio_access(); + return this->as().type(); + } + HALIDE_ATTRIBUTE_DEPRECATED("Func::output_types() is deprecated; use Func::types() instead.") + const std::vector &output_types() const { + this->check_gio_access(); + return this->as().types(); + } HALIDE_FORWARD_METHOD_CONST(Func, outputs) HALIDE_FORWARD_METHOD_CONST(Func, rvars) + HALIDE_FORWARD_METHOD_CONST(Func, type) + HALIDE_FORWARD_METHOD_CONST(Func, types) HALIDE_FORWARD_METHOD_CONST(Func, update_args) HALIDE_FORWARD_METHOD_CONST(Func, update_value) HALIDE_FORWARD_METHOD_CONST(Func, update_values) @@ -1906,7 +2008,7 @@ class GeneratorInput_DynamicScalar : public GeneratorInputImpl { public: explicit GeneratorInput_DynamicScalar(const std::string &name) - : Super(name, IOKind::Scalar, {}, 0) { + : Super(name, ArgInfoKind::Scalar, {}, 0) { user_assert(!std::is_array::value) << "Input is not allowed"; } @@ -1930,6 +2032,10 @@ class GeneratorInput_DynamicScalar : public GeneratorInputImpl { p.set_estimate(value); } } + + Type type() const { + return Expr(*this).type(); + } }; template @@ -1969,22 +2075,22 @@ class GeneratorInput_Scalar : public GeneratorInputImpl { public: explicit GeneratorInput_Scalar(const std::string &name) - : Super(name, IOKind::Scalar, {type_of()}, 0), def_(static_cast(0)), def_expr_(Expr()) { + : Super(name, ArgInfoKind::Scalar, {type_of()}, 0), def_(static_cast(0)), def_expr_(Expr()) { } GeneratorInput_Scalar(const std::string &name, const TBase &def) - : Super(name, IOKind::Scalar, {type_of()}, 0), def_(def), def_expr_(TBaseToExpr(def)) { + : Super(name, ArgInfoKind::Scalar, {type_of()}, 0), def_(def), def_expr_(TBaseToExpr(def)) { } GeneratorInput_Scalar(size_t array_size, const std::string &name) - : Super(array_size, name, IOKind::Scalar, {type_of()}, 0), def_(static_cast(0)), def_expr_(Expr()) { + : Super(array_size, name, ArgInfoKind::Scalar, {type_of()}, 0), def_(static_cast(0)), def_expr_(Expr()) { } GeneratorInput_Scalar(size_t array_size, const std::string &name, const TBase &def) - : Super(array_size, name, IOKind::Scalar, {type_of()}, 0), def_(def), def_expr_(TBaseToExpr(def)) { + : Super(array_size, name, ArgInfoKind::Scalar, {type_of()}, 0), def_(def), def_expr_(TBaseToExpr(def)) { } /** You can use this Input as an expression in a halide @@ -2032,6 +2138,10 @@ class GeneratorInput_Scalar : public GeneratorInputImpl { } this->parameters_.at(index).set_estimate(e); } + + Type type() const { + return Expr(*this).type(); + } }; template @@ -2199,8 +2309,9 @@ class GeneratorOutputBase : public GIOBase { template::value>::type * = nullptr> HALIDE_NO_USER_CODE_INLINE T2 as() const { static_assert(std::is_same::value, "Only Func allowed here"); - internal_assert(kind() != IOKind::Scalar); + internal_assert(kind() != ArgInfoKind::Scalar); internal_assert(exprs_.empty()); + user_assert(!funcs_.empty()) << "No funcs_ are defined yet"; user_assert(funcs_.size() == 1) << "Use [] to access individual Funcs in Output"; return funcs_[0]; } @@ -2223,6 +2334,7 @@ class GeneratorOutputBase : public GIOBase { HALIDE_FORWARD_METHOD(Func, copy_to_host) HALIDE_FORWARD_METHOD(Func, define_extern) HALIDE_FORWARD_METHOD_CONST(Func, defined) + HALIDE_FORWARD_METHOD_CONST(Func, dimensions) HALIDE_FORWARD_METHOD(Func, fold_storage) HALIDE_FORWARD_METHOD(Func, fuse) HALIDE_FORWARD_METHOD(Func, gpu) @@ -2235,7 +2347,16 @@ class GeneratorOutputBase : public GIOBase { HALIDE_FORWARD_METHOD(Func, in) HALIDE_FORWARD_METHOD(Func, memoize) HALIDE_FORWARD_METHOD_CONST(Func, num_update_definitions) - HALIDE_FORWARD_METHOD_CONST(Func, output_types) + HALIDE_ATTRIBUTE_DEPRECATED("Func::output_type() is deprecated; use Func::type() instead.") + const Type &output_type() const { + this->check_gio_access(); + return this->as().type(); + } + HALIDE_ATTRIBUTE_DEPRECATED("Func::output_types() is deprecated; use Func::types() instead.") + const std::vector &output_types() const { + this->check_gio_access(); + return this->as().types(); + } HALIDE_FORWARD_METHOD_CONST(Func, outputs) HALIDE_FORWARD_METHOD(Func, parallel) HALIDE_FORWARD_METHOD(Func, prefetch) @@ -2253,6 +2374,8 @@ class GeneratorOutputBase : public GIOBase { HALIDE_FORWARD_METHOD(Func, store_root) HALIDE_FORWARD_METHOD(Func, tile) HALIDE_FORWARD_METHOD(Func, trace_stores) + HALIDE_FORWARD_METHOD_CONST(Func, type) + HALIDE_FORWARD_METHOD_CONST(Func, types) HALIDE_FORWARD_METHOD(Func, unroll) HALIDE_FORWARD_METHOD(Func, update) HALIDE_FORWARD_METHOD_CONST(Func, update_args) @@ -2261,6 +2384,7 @@ class GeneratorOutputBase : public GIOBase { HALIDE_FORWARD_METHOD_CONST(Func, value) HALIDE_FORWARD_METHOD_CONST(Func, values) HALIDE_FORWARD_METHOD(Func, vectorize) + // }@ #undef HALIDE_OUTPUT_FORWARD @@ -2269,12 +2393,12 @@ class GeneratorOutputBase : public GIOBase { protected: GeneratorOutputBase(size_t array_size, const std::string &name, - IOKind kind, + ArgInfoKind kind, const std::vector &t, int d); GeneratorOutputBase(const std::string &name, - IOKind kind, + ArgInfoKind kind, const std::vector &t, int d); @@ -2311,21 +2435,21 @@ class GeneratorOutputImpl : public GeneratorOutputBase { template::value>::type * = nullptr> - GeneratorOutputImpl(const std::string &name, IOKind kind, const std::vector &t, int d) + GeneratorOutputImpl(const std::string &name, ArgInfoKind kind, const std::vector &t, int d) : GeneratorOutputBase(name, kind, t, d) { } template::value && std::rank::value == 1 && (std::extent::value > 0)>::type * = nullptr> - GeneratorOutputImpl(const std::string &name, IOKind kind, const std::vector &t, int d) + GeneratorOutputImpl(const std::string &name, ArgInfoKind kind, const std::vector &t, int d) : GeneratorOutputBase(std::extent::value, name, kind, t, d) { } template::value && std::rank::value == 1 && std::extent::value == 0>::type * = nullptr> - GeneratorOutputImpl(const std::string &name, IOKind kind, const std::vector &t, int d) + GeneratorOutputImpl(const std::string &name, ArgInfoKind kind, const std::vector &t, int d) : GeneratorOutputBase(-1, name, kind, t, d) { } @@ -2403,24 +2527,24 @@ class GeneratorOutput_Buffer : public GeneratorOutputImpl { internal_assert(f.defined()); - if (this->types_defined()) { - const auto &my_types = this->types(); - user_assert(my_types.size() == f.output_types().size()) + if (this->gio_types_defined()) { + const auto &my_types = this->gio_types(); + user_assert(my_types.size() == f.types().size()) << "Cannot assign Func \"" << f.name() << "\" to Output \"" << this->name() << "\"\n" << "Output " << this->name() << " is declared to have " << my_types.size() << " tuple elements" << " but Func " << f.name() - << " has " << f.output_types().size() << " tuple elements.\n"; + << " has " << f.types().size() << " tuple elements.\n"; for (size_t i = 0; i < my_types.size(); i++) { - user_assert(my_types[i] == f.output_types().at(i)) + user_assert(my_types[i] == f.types().at(i)) << "Cannot assign Func \"" << f.name() << "\" to Output \"" << this->name() << "\"\n" << (my_types.size() > 1 ? "In tuple element " + std::to_string(i) + ", " : "") << "Output " << this->name() << " has declared type " << my_types[i] << " but Func " << f.name() - << " has type " << f.output_types().at(i) << "\n"; + << " has type " << f.types().at(i) << "\n"; } } if (this->dims_defined()) { @@ -2442,13 +2566,13 @@ class GeneratorOutput_Buffer : public GeneratorOutputImpl { using TBase = typename Super::TBase; explicit GeneratorOutput_Buffer(const std::string &name) - : Super(name, IOKind::Buffer, + : Super(name, ArgInfoKind::Buffer, TBase::has_static_halide_type ? std::vector{TBase::static_halide_type()} : std::vector{}, TBase::has_static_dimensions ? TBase::static_dimensions() : -1) { } GeneratorOutput_Buffer(const std::string &name, const std::vector &t, int d) - : Super(name, IOKind::Buffer, t, d) { + : Super(name, ArgInfoKind::Buffer, t, d) { internal_assert(!t.empty()); internal_assert(d != -1); static_assert(!TBase::has_static_halide_type, "You can only specify a Type argument for Output> if T is void or omitted."); @@ -2456,13 +2580,13 @@ class GeneratorOutput_Buffer : public GeneratorOutputImpl { } GeneratorOutput_Buffer(const std::string &name, const std::vector &t) - : Super(name, IOKind::Buffer, t, -1) { + : Super(name, ArgInfoKind::Buffer, t, -1) { internal_assert(!t.empty()); static_assert(!TBase::has_static_halide_type, "You can only specify a Type argument for Output> if T is void or omitted."); } GeneratorOutput_Buffer(const std::string &name, int d) - : Super(name, IOKind::Buffer, + : Super(name, ArgInfoKind::Buffer, TBase::has_static_halide_type ? std::vector{TBase::static_halide_type()} : std::vector{}, d) { internal_assert(d != -1); @@ -2470,13 +2594,13 @@ class GeneratorOutput_Buffer : public GeneratorOutputImpl { } GeneratorOutput_Buffer(size_t array_size, const std::string &name) - : Super(array_size, name, IOKind::Buffer, + : Super(array_size, name, ArgInfoKind::Buffer, TBase::has_static_halide_type ? std::vector{TBase::static_halide_type()} : std::vector{}, TBase::has_static_dimensions ? TBase::static_dimensions() : -1) { } GeneratorOutput_Buffer(size_t array_size, const std::string &name, const std::vector &t, int d) - : Super(array_size, name, IOKind::Buffer, t, d) { + : Super(array_size, name, ArgInfoKind::Buffer, t, d) { internal_assert(!t.empty()); internal_assert(d != -1); static_assert(!TBase::has_static_halide_type, "You can only specify a Type argument for Output> if T is void or omitted."); @@ -2484,13 +2608,13 @@ class GeneratorOutput_Buffer : public GeneratorOutputImpl { } GeneratorOutput_Buffer(size_t array_size, const std::string &name, const std::vector &t) - : Super(array_size, name, IOKind::Buffer, t, -1) { + : Super(array_size, name, ArgInfoKind::Buffer, t, -1) { internal_assert(!t.empty()); static_assert(!TBase::has_static_halide_type, "You can only specify a Type argument for Output> if T is void or omitted."); } GeneratorOutput_Buffer(size_t array_size, const std::string &name, int d) - : Super(array_size, name, IOKind::Buffer, + : Super(array_size, name, ArgInfoKind::Buffer, TBase::has_static_halide_type ? std::vector{TBase::static_halide_type()} : std::vector{}, d) { internal_assert(d != -1); @@ -2528,9 +2652,9 @@ class GeneratorOutput_Buffer : public GeneratorOutputImpl { << "Cannot assign to the Output \"" << this->name() << "\": the expression is not convertible to the same Buffer type and/or dimensions.\n"; - if (this->types_defined()) { - user_assert(Type(buffer.type()) == this->type()) - << "Output " << this->name() << " should have type=" << this->type() << " but saw type=" << Type(buffer.type()) << "\n"; + if (this->gio_types_defined()) { + user_assert(Type(buffer.type()) == this->gio_type()) + << "Output " << this->name() << " should have type=" << this->gio_type() << " but saw type=" << Type(buffer.type()) << "\n"; } if (this->dims_defined()) { user_assert(buffer.dimensions() == this->dims()) @@ -2624,23 +2748,23 @@ class GeneratorOutput_Func : public GeneratorOutputImpl { using TBase = typename Super::TBase; explicit GeneratorOutput_Func(const std::string &name) - : Super(name, IOKind::Function, std::vector{}, -1) { + : Super(name, ArgInfoKind::Function, std::vector{}, -1) { } GeneratorOutput_Func(const std::string &name, const std::vector &t, int d) - : Super(name, IOKind::Function, t, d) { + : Super(name, ArgInfoKind::Function, t, d) { } GeneratorOutput_Func(const std::string &name, const std::vector &t) - : Super(name, IOKind::Function, t, -1) { + : Super(name, ArgInfoKind::Function, t, -1) { } GeneratorOutput_Func(const std::string &name, int d) - : Super(name, IOKind::Function, {}, d) { + : Super(name, ArgInfoKind::Function, {}, d) { } GeneratorOutput_Func(size_t array_size, const std::string &name, const std::vector &t, int d) - : Super(array_size, name, IOKind::Function, t, d) { + : Super(array_size, name, ArgInfoKind::Function, t, d) { } public: @@ -2699,11 +2823,11 @@ class GeneratorOutput_Arithmetic : public GeneratorOutputImpl { using TBase = typename Super::TBase; explicit GeneratorOutput_Arithmetic(const std::string &name) - : Super(name, IOKind::Function, {type_of()}, 0) { + : Super(name, ArgInfoKind::Function, {type_of()}, 0) { } GeneratorOutput_Arithmetic(size_t array_size, const std::string &name) - : Super(array_size, name, IOKind::Function, {type_of()}, 0) { + : Super(array_size, name, ArgInfoKind::Function, {type_of()}, 0) { } }; @@ -2892,8 +3016,6 @@ class GeneratorParam_Synthetic : public GeneratorParamImpl { const std::string error_msg; }; -class GeneratorStub; - } // namespace Internal /** GeneratorContext is a class that is used when using Generators (or Stubs) directly; @@ -2914,7 +3036,7 @@ class GeneratorStub; * \endcode * * Note that all Generators embed a GeneratorContext, so if you are using a Stub - * from within a Generator, you can just pass 'contex()' for the GeneratorContext: + * from within a Generator, you can just pass 'context()' for the GeneratorContext: * \code * struct SomeGen : Generator { * void generate() { @@ -2935,11 +3057,19 @@ class GeneratorContext { public: friend class Internal::GeneratorBase; +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE using ExternsMap = std::map; +#endif +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API explicit GeneratorContext(const Target &t, bool auto_schedule = false, const MachineParams &machine_params = MachineParams::generic()); +#else + explicit GeneratorContext(const Target &t); + explicit GeneratorContext(const Target &t, + const AutoschedulerParams &autoscheduler_params); +#endif GeneratorContext() = default; GeneratorContext(const GeneratorContext &) = default; @@ -2947,15 +3077,41 @@ class GeneratorContext { GeneratorContext(GeneratorContext &&) = default; GeneratorContext &operator=(GeneratorContext &&) = default; + const Target &target() const { + return target_; + } +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + bool auto_schedule() const { + return auto_schedule_; + } + const MachineParams &machine_params() const { + return machine_params_; + } +#else + const AutoschedulerParams &autoscheduler_params() const { + return autoscheduler_params_; + } +#endif + + HALIDE_ATTRIBUTE_DEPRECATED("Call GeneratorContext::target() instead of GeneratorContext::get_target().") const Target &get_target() const { return target_; } +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + HALIDE_ATTRIBUTE_DEPRECATED("Call GeneratorContext::auto_schedule() instead of GeneratorContext::get_auto_schedule().") bool get_auto_schedule() const { return auto_schedule_; } + HALIDE_ATTRIBUTE_DEPRECATED("Call GeneratorContext::machine_params() instead of GeneratorContext::get_machine_params().") const MachineParams &get_machine_params() const { return machine_params_; } +#endif + + // Return a copy of this GeneratorContext that uses the given Target. + // This method is rarely needed; it's really provided as a convenience + // for use with init_from_context(). + GeneratorContext with_target(const Target &t) const; template inline std::unique_ptr create() const { @@ -2970,16 +3126,26 @@ class GeneratorContext { private: Target target_; +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API bool auto_schedule_ = false; MachineParams machine_params_ = MachineParams::generic(); - std::shared_ptr externs_map_; - std::shared_ptr value_tracker_; +#else + AutoschedulerParams autoscheduler_params_; +#endif +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE + std::shared_ptr externs_map_ = std::make_shared(); +#endif +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE GeneratorContext(const Target &target, +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API bool auto_schedule, const MachineParams &machine_params, - std::shared_ptr externs_map, - std::shared_ptr value_tracker); +#else + const AutoschedulerParams &autoscheduler_params, +#endif + std::shared_ptr externs_map); +#endif // HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE }; class NamesInterface { @@ -3045,28 +3211,9 @@ struct NoRealizations { static const bool value = !std::is_convertible::value && NoRealizations::value; }; -class GeneratorStub; - // Note that these functions must never return null: // if they cannot return a valid Generator, they must assert-fail. -using GeneratorFactory = std::function(const GeneratorContext &)>; - -struct StringOrLoopLevel { - std::string string_value; - LoopLevel loop_level; - - StringOrLoopLevel() = default; - /*not-explicit*/ StringOrLoopLevel(const char *s) - : string_value(s) { - } - /*not-explicit*/ StringOrLoopLevel(const std::string &s) - : string_value(s) { - } - /*not-explicit*/ StringOrLoopLevel(const LoopLevel &loop_level) - : loop_level(loop_level) { - } -}; -using GeneratorParamsMap = std::map; +using GeneratorFactory = std::function; class GeneratorParamInfo { // names used across all params, inputs, and outputs. @@ -3105,11 +3252,9 @@ class GeneratorParamInfo { } }; -class GeneratorBase : public NamesInterface { +class GeneratorBase : public NamesInterface, public AbstractGenerator { public: - virtual ~GeneratorBase(); - - void set_generator_param_values(const GeneratorParamsMap ¶ms); + ~GeneratorBase() override; /** Given a data type, return an estimate of the "natural" vector size * for that data type when compiling for the current target. */ @@ -3124,29 +3269,6 @@ class GeneratorBase : public NamesInterface { return get_target().natural_vector_size(); } - void emit_cpp_stub(const std::string &stub_file_path); - - // Call generate() and produce a Module for the result. - // If function_name is empty, generator_name() will be used for the function. - Module build_module(const std::string &function_name = "", - LinkageType linkage_type = LinkageType::ExternalPlusMetadata); - - /** - * Build a module that is suitable for using for gradient descent calculation in TensorFlow or PyTorch. - * - * Essentially: - * - A new Pipeline is synthesized from the current Generator (according to the rules below) - * - The new Pipeline is autoscheduled (if autoscheduling is requested, but it would be odd not to do so) - * - The Pipeline is compiled to a Module and returned - * - * The new Pipeline is adjoint to the original; it has: - * - All the same inputs as the original, in the same order - * - Followed by one grad-input for each original output - * - Followed by one output for each unique pairing of original-output + original-input. - * (For the common case of just one original-output, this amounts to being one output for each original-input.) - */ - Module build_gradient_module(const std::string &function_name); - /** * set_inputs is a variadic wrapper around set_inputs_vector, which makes usage much simpler * in many cases, as it constructs the relevant entries for the vector for you, which @@ -3185,20 +3307,11 @@ class GeneratorBase : public NamesInterface { get_pipeline().realize(r, get_target()); } -#ifdef HALIDE_ALLOW_GENERATOR_BUILD_METHOD - // Return the Pipeline that has been built by the generate() method. - // This method can only be used from a Generator that has a generate() - // method (vs a build() method), and currently can only be called from - // the schedule() method. (This may be relaxed in the future to allow - // calling from generate() as long as all Outputs have been defined.) - Pipeline get_pipeline(); -#else // Return the Pipeline that has been built by the generate() method. // This method can only be called from the schedule() method. // (This may be relaxed in the future to allow calling from generate() as // long as all Outputs have been defined.) Pipeline get_pipeline(); -#endif // Create Input with dynamic type & dimensions templateGenerateCalled directly.) InputsSet, -#ifdef HALIDE_ALLOW_GENERATOR_BUILD_METHOD - // Generator has had its generate() method called. (For Generators with - // a build() method instead of generate(), this phase will be skipped - // and will advance directly to ScheduleCalled.) - GenerateCalled, -#else // Generator has had its generate() method called. GenerateCalled, -#endif // Generator has had its schedule() method (if any) called. ScheduleCalled, @@ -3409,12 +3565,23 @@ class GeneratorBase : public NamesInterface { Target get_target() const { return target; } +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API bool get_auto_schedule() const { return auto_schedule; } MachineParams get_machine_params() const { return machine_params; } + bool using_autoscheduler() const { + return get_auto_schedule(); + } +#else + bool using_autoscheduler() const { + return !autoscheduler_.value().name.empty(); + } +#endif + +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE /** Generators can register ExternalCode objects onto * themselves. The Generator infrastructure will arrange to have * this ExternalCode appended to the Module that is finally @@ -3433,11 +3600,20 @@ class GeneratorBase : public NamesInterface { std::shared_ptr get_externs_map() const { return externs_map; } +#else + /** ExternalCode objects in Generator are deprecated in Halide 15 and will + * be removed in Halide 16. You may continue to use them in Halide 15 + * by defining HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE when building Halide. */ +#endif // These must remain here for legacy code that access the fields directly. GeneratorParam target{"target", Target()}; +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API GeneratorParam auto_schedule{"auto_schedule", false}; GeneratorParam machine_params{"machine_params", MachineParams::generic()}; +#else + GeneratorParam_AutoSchedulerParams autoscheduler_; +#endif private: friend void ::Halide::Internal::generator_test(); @@ -3446,24 +3622,34 @@ class GeneratorBase : public NamesInterface { friend class GeneratorInputBase; friend class GeneratorOutputBase; friend class GeneratorParamInfo; - friend class GeneratorStub; friend class StubOutputBufferBase; const size_t size; +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE std::shared_ptr externs_map; - std::shared_ptr value_tracker; +#endif // Lazily-allocated-and-inited struct with info about our various Params. // Do not access directly: use the param_info() getter. std::unique_ptr param_info_ptr; - bool inputs_set{false}; std::string generator_registered_name, generator_stub_name; Pipeline pipeline; // Return our GeneratorParamInfo. GeneratorParamInfo ¶m_info(); + template + T *find_by_name(const std::string &name, const std::vector &v) { + for (T *t : v) { + if (t->name() == name) { + return t; + } + } + return nullptr; + } + + Internal::GeneratorInputBase *find_input_by_name(const std::string &name); Internal::GeneratorOutputBase *find_output_by_name(const std::string &name); void check_scheduled(const char *m) const; @@ -3477,17 +3663,11 @@ class GeneratorBase : public NamesInterface { void get_jit_target_from_environment(); void get_target_from_environment(); - // Return the output with the given name. - // If the output is singular (a non-array), return a vector of size 1. - // If no such name exists (or is non-array), assert. - // This method never returns undefined Funcs. - std::vector get_outputs(const std::string &n); - void set_inputs_vector(const std::vector> &inputs); static void check_input_is_singular(Internal::GeneratorInputBase *in); static void check_input_is_array(Internal::GeneratorInputBase *in); - static void check_input_kind(Internal::GeneratorInputBase *in, Internal::IOKind kind); + static void check_input_kind(Internal::GeneratorInputBase *in, Internal::ArgInfoKind kind); // Allow Buffer<> if: // -- we are assigning it to an Input> (with compatible type and dimensions), @@ -3498,18 +3678,18 @@ class GeneratorBase : public NamesInterface { auto *in = param_info().inputs().at(i); check_input_is_singular(in); const auto k = in->kind(); - if (k == Internal::IOKind::Buffer) { + if (k == Internal::ArgInfoKind::Buffer) { Halide::Buffer<> b = arg; StubInputBuffer<> sib(b); StubInput si(sib); return {si}; - } else if (k == Internal::IOKind::Function) { + } else if (k == Internal::ArgInfoKind::Function) { Halide::Func f(arg.name() + "_im"); f(Halide::_) = arg(Halide::_); StubInput si(f); return {si}; } else { - check_input_kind(in, Internal::IOKind::Buffer); // just to trigger assertion + check_input_kind(in, Internal::ArgInfoKind::Buffer); // just to trigger assertion return {}; } } @@ -3523,16 +3703,16 @@ class GeneratorBase : public NamesInterface { auto *in = param_info().inputs().at(i); check_input_is_singular(in); const auto k = in->kind(); - if (k == Internal::IOKind::Buffer) { + if (k == Internal::ArgInfoKind::Buffer) { StubInputBuffer<> sib = arg; StubInput si(sib); return {si}; - } else if (k == Internal::IOKind::Function) { + } else if (k == Internal::ArgInfoKind::Function) { Halide::Func f = arg.funcs().at(0); StubInput si(f); return {si}; } else { - check_input_kind(in, Internal::IOKind::Buffer); // just to trigger assertion + check_input_kind(in, Internal::ArgInfoKind::Buffer); // just to trigger assertion return {}; } } @@ -3540,7 +3720,7 @@ class GeneratorBase : public NamesInterface { // Allow Func iff we are assigning it to an Input (with compatible type and dimensions). std::vector build_input(size_t i, const Func &arg) { auto *in = param_info().inputs().at(i); - check_input_kind(in, Internal::IOKind::Function); + check_input_kind(in, Internal::ArgInfoKind::Function); check_input_is_singular(in); const Halide::Func &f = arg; StubInput si(f); @@ -3550,7 +3730,7 @@ class GeneratorBase : public NamesInterface { // Allow vector iff we are assigning it to an Input (with compatible type and dimensions). std::vector build_input(size_t i, const std::vector &arg) { auto *in = param_info().inputs().at(i); - check_input_kind(in, Internal::IOKind::Function); + check_input_kind(in, Internal::ArgInfoKind::Function); check_input_is_array(in); // My kingdom for a list comprehension... std::vector siv; @@ -3564,7 +3744,7 @@ class GeneratorBase : public NamesInterface { // Expr must be Input. std::vector build_input(size_t i, const Expr &arg) { auto *in = param_info().inputs().at(i); - check_input_kind(in, Internal::IOKind::Scalar); + check_input_kind(in, Internal::ArgInfoKind::Scalar); check_input_is_singular(in); StubInput si(arg); return {si}; @@ -3573,7 +3753,7 @@ class GeneratorBase : public NamesInterface { // (Array form) std::vector build_input(size_t i, const std::vector &arg) { auto *in = param_info().inputs().at(i); - check_input_kind(in, Internal::IOKind::Scalar); + check_input_kind(in, Internal::ArgInfoKind::Scalar); check_input_is_array(in); std::vector siv; siv.reserve(arg.size()); @@ -3589,7 +3769,7 @@ class GeneratorBase : public NamesInterface { typename std::enable_if::value>::type * = nullptr> std::vector build_input(size_t i, const T &arg) { auto *in = param_info().inputs().at(i); - check_input_kind(in, Internal::IOKind::Scalar); + check_input_kind(in, Internal::ArgInfoKind::Scalar); check_input_is_singular(in); // We must use an explicit Expr() ctor to preserve the type Expr e(arg); @@ -3602,7 +3782,7 @@ class GeneratorBase : public NamesInterface { typename std::enable_if::value>::type * = nullptr> std::vector build_input(size_t i, const std::vector &arg) { auto *in = param_info().inputs().at(i); - check_input_kind(in, Internal::IOKind::Scalar); + check_input_kind(in, Internal::ArgInfoKind::Scalar); check_input_is_array(in); std::vector siv; siv.reserve(arg.size()); @@ -3620,7 +3800,44 @@ class GeneratorBase : public NamesInterface { return {build_input(Indices, std::get(t))...}; } + // Note that this deliberately ignores inputs/outputs with multiple array values + // (ie, one name per input or output, regardless of array_size()) + template + static void get_arguments(std::vector &args, ArgInfoDirection dir, const T &t) { + for (auto *e : t) { + args.push_back({e->name(), + dir, + e->kind(), + e->gio_types_defined() ? e->gio_types() : std::vector{}, + e->dims_defined() ? e->dims() : 0}); + } + } + public: + // AbstractGenerator methods + std::string name() override; + GeneratorContext context() const override; + std::vector arginfos() override; + + void set_generatorparam_value(const std::string &name, const std::string &value) override; + void set_generatorparam_value(const std::string &name, const LoopLevel &loop_level) override; + + std::vector input_parameter(const std::string &name) override; + std::vector output_func(const std::string &name) override; + +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE + ExternsMap external_code_map() override; +#endif + + // This is overridden in the concrete Generator<> subclass. + // Pipeline build_pipeline() override; + + void bind_input(const std::string &name, const std::vector &v) override; + void bind_input(const std::string &name, const std::vector &v) override; + void bind_input(const std::string &name, const std::vector &v) override; + + bool emit_cpp_stub(const std::string &stub_file_path) override; + GeneratorBase(const GeneratorBase &) = delete; GeneratorBase &operator=(const GeneratorBase &) = delete; GeneratorBase(GeneratorBase &&that) = delete; @@ -3632,10 +3849,10 @@ class GeneratorRegistry { static void register_factory(const std::string &name, GeneratorFactory generator_factory); static void unregister_factory(const std::string &name); static std::vector enumerate(); - // Note that this method will never return null: - // if it cannot return a valid Generator, it should assert-fail. - static std::unique_ptr create(const std::string &name, - const Halide::GeneratorContext &context); + // This method returns nullptr if it cannot return a valid Generator; + // the caller is responsible for checking the result. + static AbstractGeneratorPtr create(const std::string &name, + const Halide::GeneratorContext &context); private: using GeneratorFactoryMap = std::map; @@ -3684,14 +3901,6 @@ class Generator : public Internal::GeneratorBase { template void apply(const Args &...args) { -#ifdef HALIDE_ALLOW_GENERATOR_BUILD_METHOD -#ifndef _MSC_VER - // VS2015 apparently has some SFINAE issues, so this can inappropriately - // trigger there. (We'll still fail when generate() is called, just - // with a less-helpful error message.) - static_assert(has_generate_method::value, "apply() is not supported for old-style Generators."); -#endif -#endif call_configure(); set_inputs(args...); call_generate(); @@ -3734,118 +3943,6 @@ class Generator : public Internal::GeneratorBase { template struct has_schedule_method().schedule())>::type> : std::true_type {}; -#ifdef HALIDE_ALLOW_GENERATOR_BUILD_METHOD - // Implementations for build_pipeline_impl(), specialized on whether we - // have build() or generate()/schedule() methods. - - // MSVC apparently has some weirdness with the usual sfinae tricks - // for detecting method-shaped things, so we can't actually use - // the helpers above outside of static_assert. Instead we make as - // many overloads as we can exist, and then use C++'s preference - // for treating a 0 as an int rather than a double to choose one - // of them. - template::value>::type * = nullptr> - HALIDE_ATTRIBUTE_DEPRECATED("The build() method is deprecated for Halide Generators and will be removed entirely in future versions of Halide. Please use a generate() method with Output<> members instead.") - Pipeline build_pipeline_impl(double) { - static_assert(!has_configure_method::value, "The configure() method is ignored if you define a build() method; use generate() instead."); - static_assert(!has_schedule_method::value, "The schedule() method is ignored if you define a build() method; use generate() instead."); - - user_warning << "The build() method is deprecated for Halide Generators and will be removed entirely in future versions of Halide. " - << "Please use a generate() method with Output<> members instead.\n"; - - pre_build(); - Pipeline p = ((T *)this)->build(); - post_build(); - return p; - } - - template().generate())> - Pipeline build_pipeline_impl(int) { - // No: configure() must be called prior to this - // (and in fact, prior to calling set_inputs). - // - // ((T *)this)->call_configure_impl(0, 0); - - ((T *)this)->call_generate_impl(0); - ((T *)this)->call_schedule_impl(0, 0); - return get_pipeline(); - } - - // Implementations for call_configure_impl(), specialized on whether we - // have build() or configure()/generate()/schedule() methods. - - void call_configure_impl(double, double) { - pre_configure(); - // Called as a side effect for build()-method Generators; quietly do nothing - // (except for pre_configure(), to advance the phase). - post_configure(); - } - - template().generate())> - void call_configure_impl(double, int) { - // Generator has a generate() method but no configure() method. This is ok. Just advance the phase. - pre_configure(); - static_assert(!has_configure_method::value, "Did not expect a configure method here."); - post_configure(); - } - - template().generate()), - typename = decltype(std::declval().configure())> - void call_configure_impl(int, int) { - T *t = (T *)this; - static_assert(std::is_voidconfigure())>::value, "configure() must return void"); - pre_configure(); - t->configure(); - post_configure(); - } - - // Implementations for call_generate_impl(), specialized on whether we - // have build() or configure()/generate()/schedule() methods. - - void call_generate_impl(double) { - user_error << "Unimplemented"; - } - - template().generate())> - void call_generate_impl(int) { - T *t = (T *)this; - static_assert(std::is_voidgenerate())>::value, "generate() must return void"); - pre_generate(); - t->generate(); - post_generate(); - } - - // Implementations for call_schedule_impl(), specialized on whether we - // have build() or configure()generate()/schedule() methods. - - void call_schedule_impl(double, double) { - user_error << "Unimplemented"; - } - - template().generate())> - void call_schedule_impl(double, int) { - // Generator has a generate() method but no schedule() method. This is ok. Just advance the phase. - pre_schedule(); - post_schedule(); - } - - template().generate()), - typename = decltype(std::declval().schedule())> - void call_schedule_impl(int, int) { - T *t = (T *)this; - static_assert(std::is_voidschedule())>::value, "schedule() must return void"); - pre_schedule(); - t->schedule(); - post_schedule(); - } -#else Pipeline build_pipeline_impl() { T *t = (T *)this; // No: configure() must be called prior to this @@ -3886,27 +3983,10 @@ class Generator : public Internal::GeneratorBase { } post_schedule(); } -#endif protected: -#ifdef HALIDE_ALLOW_GENERATOR_BUILD_METHOD - Pipeline build_pipeline() override { - return this->build_pipeline_impl(0); - } - - void call_configure() override { - this->call_configure_impl(0, 0); - } - - void call_generate() override { - this->call_generate_impl(0); - } - - void call_schedule() override { - this->call_schedule_impl(0, 0); - } -#else Pipeline build_pipeline() override { + ensure_configure_has_been_called(); return this->build_pipeline_impl(); } @@ -3921,7 +4001,7 @@ class Generator : public Internal::GeneratorBase { void call_schedule() override { this->call_schedule_impl(); } -#endif + private: friend void ::Halide::Internal::generator_test(); friend void ::Halide::Internal::generator_test(); @@ -3941,63 +4021,97 @@ class RegisterGenerator { RegisterGenerator(const char *registered_name, GeneratorFactory generator_factory); }; -class GeneratorStub : public NamesInterface { -public: - GeneratorStub(const GeneratorContext &context, - const GeneratorFactory &generator_factory); - - GeneratorStub(const GeneratorContext &context, - const GeneratorFactory &generator_factory, - const GeneratorParamsMap &generator_params, - const std::vector> &inputs); - std::vector> generate(const GeneratorParamsMap &generator_params, - const std::vector> &inputs); - - // Output(s) - std::vector get_outputs(const std::string &n) const { - return generator->get_outputs(n); - } - - template - std::vector get_output_buffers(const std::string &n) const { - auto v = generator->get_outputs(n); - std::vector result; - for (auto &o : v) { - result.push_back(T2(o, generator)); - } - return result; - } - - static std::vector to_stub_input_vector(const Expr &e) { - return {StubInput(e)}; - } +// ----------------------------- - static std::vector to_stub_input_vector(const Func &f) { - return {StubInput(f)}; - } +/** ExecuteGeneratorArgs is the set of arguments to execute_generator(). + */ +struct ExecuteGeneratorArgs { + // Output directory for all files generated. Must not be empty. + std::string output_dir; + + // Type(s) of outputs to produce. Must not be empty. + std::set output_types; + + // Target(s) to use when generating. Must not be empty. + // If list contains multiple entries, a multitarget output will be produced. + std::vector targets; + + // When generating multitarget output, use these as the suffixes for each Target + // specified by the targets field. If empty, the canonical string form of + // each Target will be used. If nonempty, it must be the same length as the + // targets vector. + std::vector suffixes; + + // Name of the generator to execute (or empty if none, e.g. if generating a runtime) + // Must be one recognized by the specified GeneratorFactoryProvider. + std::string generator_name; + + // Name to use for the generated function. May include C++ namespaces, + // e.g. "HalideTest::AnotherNamespace::cxx_mangling". If empty, use `generator_name`. + std::string function_name; + + // Base filename for all outputs (differentated by file extension). + // If empty, use `function_name` (ignoring any C++ namespaces). + std::string file_base_name; + + // The name of a standalone runtime to generate. Only honors EMIT_OPTIONS 'o' + // and 'static_library'. When multiple targets are specified, it picks a + // runtime that is compatible with all of the targets, or fails if it cannot + // find one. Flags across all of the targets that do not affect runtime code + // generation, such as `no_asserts` and `no_runtime`, are ignored. + std::string runtime_name; + + // The mode in which to build the Generator. + enum BuildMode { + // Build it as written. + Default, + + // Build a version suitable for using for gradient descent calculation. + Gradient + } build_mode = Default; + + // The fn that will produce Generator(s) from the name specified. + // (Note that `generator_name` is the only value that will ever be passed + // for name here; it is provided for ease of interoperation with existing code.) + // + // If null, the default global registry of Generators will be used. + using CreateGeneratorFn = std::function; + CreateGeneratorFn create_generator = nullptr; - template - static std::vector to_stub_input_vector(const StubInputBuffer &b) { - return {StubInput(b)}; - } + // Values to substitute for GeneratorParams in the selected Generator. + // Should not contain `target`. + // + // If any of the generator param names specified in this map are unknown + // to the Generator created, an error will occur. + GeneratorParamsMap generator_params; - template - static std::vector to_stub_input_vector(const std::vector &v) { - std::vector r; - std::copy(v.begin(), v.end(), std::back_inserter(r)); - return r; - } + // Compiler Logger to use, for diagnostic work. If null, don't do any logging. + CompilerLoggerFactory compiler_logger_factory = nullptr; +}; - struct Names { - std::vector generator_params, inputs, outputs; - }; - Names get_names() const; +/** + * Execute a Generator for AOT compilation -- this provides the implementation of + * the command-line Generator interface `generate_filter_main()`, but with a structured + * API that is more suitable for calling directly from code (vs command line). + */ +void execute_generator(const ExecuteGeneratorArgs &args); - std::shared_ptr generator; -}; +// ----------------------------- } // namespace Internal +/** Create a Generator from the currently-registered Generators, use it to create a Callable. + * Any GeneratorParams specified will be applied to the Generator before compilation. + * If the name isn't registered, assert-fail. */ +// @{ +Callable create_callable_from_generator(const GeneratorContext &context, + const std::string &name, + const GeneratorParamsMap &generator_params = {}); +Callable create_callable_from_generator(const Target &target, + const std::string &name, + const GeneratorParamsMap &generator_params = {}); +// @} + } // namespace Halide // Define this namespace at global scope so that anonymous namespaces won't @@ -4011,8 +4125,8 @@ struct halide_global_ns; namespace halide_register_generator { \ struct halide_global_ns; \ namespace GEN_REGISTRY_NAME##_ns { \ - std::unique_ptr factory(const Halide::GeneratorContext &context); \ - std::unique_ptr factory(const Halide::GeneratorContext &context) { \ + std::unique_ptr factory(const Halide::GeneratorContext &context); \ + std::unique_ptr factory(const Halide::GeneratorContext &context) { \ using GenType = std::remove_pointer::type; /* NOLINT(bugprone-macro-parentheses) */ \ return GenType::create(context, #GEN_REGISTRY_NAME, #FULLY_QUALIFIED_STUB_NAME); \ } \ @@ -4073,13 +4187,13 @@ struct halide_global_ns; namespace halide_register_generator { \ struct halide_global_ns; \ namespace ORIGINAL_REGISTRY_NAME##_ns { \ - std::unique_ptr factory(const Halide::GeneratorContext &context); \ + std::unique_ptr factory(const Halide::GeneratorContext &context); \ } \ namespace GEN_REGISTRY_NAME##_ns { \ - std::unique_ptr factory(const Halide::GeneratorContext &context); \ - std::unique_ptr factory(const Halide::GeneratorContext &context) { \ + std::unique_ptr factory(const Halide::GeneratorContext &context) { \ auto g = ORIGINAL_REGISTRY_NAME##_ns::factory(context); \ - g->set_generator_param_values(__VA_ARGS__); \ + const Halide::GeneratorParamsMap m = __VA_ARGS__; \ + g->set_generatorparam_values(m); \ return g; \ } \ } \ diff --git a/src/HexagonOffload.cpp b/src/HexagonOffload.cpp index 8ffd1d0c2e4d..1f1ce00b3f1e 100644 --- a/src/HexagonOffload.cpp +++ b/src/HexagonOffload.cpp @@ -286,7 +286,7 @@ void do_reloc(char *addr, uint32_t mask, uintptr_t val, bool is_signed, bool ver // Pull out the subinstructions. They're the low 13 // bits of each half-word. uint32_t hi = (inst >> 16) & ((1 << 13) - 1); - //uint32_t lo = inst & ((1 << 13) - 1); + // uint32_t lo = inst & ((1 << 13) - 1); // We only understand the ones where hi starts with 010 internal_assert((hi >> 10) == 2); @@ -989,7 +989,6 @@ Stmt inject_hexagon_rpc(Stmt s, const Target &host_target, Target::HVX_v62, Target::HVX_v65, Target::HVX_v66, - Target::DisableLLVMLoopOpt, }; for (Target::Feature i : shared_features) { if (host_target.has_feature(i)) { diff --git a/src/HexagonOptimize.cpp b/src/HexagonOptimize.cpp index 3749a9434b42..1cdc525398df 100644 --- a/src/HexagonOptimize.cpp +++ b/src/HexagonOptimize.cpp @@ -797,42 +797,6 @@ class OptimizePatterns : public IRMutator { // Halving unsigned subtract. {"halide.hexagon.navg.vub.vub", i8(widening_sub(wild_u8x, wild_u8x) >> 1)}, - // Saturating narrowing casts with rounding - {"halide.hexagon.trunc_satub_rnd.vh", u8_sat(rounding_shift_right(wild_i16x, 8)), Pattern::DeinterleaveOp0}, - {"halide.hexagon.trunc_satb_rnd.vh", i8_sat(rounding_shift_right(wild_i16x, 8)), Pattern::DeinterleaveOp0}, - {"halide.hexagon.trunc_satub_rnd.vuh", u8_sat(rounding_shift_right(wild_u16x, 8)), Pattern::DeinterleaveOp0}, - {"halide.hexagon.trunc_satuh_rnd.vw", u16_sat(rounding_shift_right(wild_i32x, 16)), Pattern::DeinterleaveOp0}, - {"halide.hexagon.trunc_sath_rnd.vw", i16_sat(rounding_shift_right(wild_i32x, 16)), Pattern::DeinterleaveOp0}, - {"halide.hexagon.trunc_satuh_rnd.vuw", u16_sat(rounding_shift_right(wild_u32x, 16)), Pattern::DeinterleaveOp0}, - - // Saturating narrowing casts with rounding - {"halide.hexagon.trunc_satub_shr_rnd.vh", u8_sat(rounding_shift_right(wild_i16x, wild_u16)), Pattern::DeinterleaveOp0}, - {"halide.hexagon.trunc_satb_shr_rnd.vh", i8_sat(rounding_shift_right(wild_i16x, wild_u16)), Pattern::DeinterleaveOp0}, - {"halide.hexagon.trunc_satub_shr_rnd.vuh", u8_sat(rounding_shift_right(wild_u16x, wild_u16)), Pattern::DeinterleaveOp0 | Pattern::v65orLater}, - {"halide.hexagon.trunc_satuh_shr_rnd.vw", u16_sat(rounding_shift_right(wild_i32x, wild_u32)), Pattern::DeinterleaveOp0}, - {"halide.hexagon.trunc_sath_shr_rnd.vw", i16_sat(rounding_shift_right(wild_i32x, wild_u32)), Pattern::DeinterleaveOp0}, - {"halide.hexagon.trunc_satuh_shr_rnd.vuw", u16_sat(rounding_shift_right(wild_u32x, wild_u32)), Pattern::DeinterleaveOp0}, - - // Saturating narrowing casts - {"halide.hexagon.trunc_satub_shr.vh.uh", u8_sat(wild_i16x >> wild_u16), Pattern::DeinterleaveOp0}, - {"halide.hexagon.trunc_satuh_shr.vw.uw", u16_sat(wild_i32x >> wild_u32), Pattern::DeinterleaveOp0}, - {"halide.hexagon.trunc_sath_shr.vw.uw", i16_sat(wild_i32x >> wild_u32), Pattern::DeinterleaveOp0}, - - // For some of the following narrowing casts, we have the choice of - // non-interleaving or interleaving instructions. Because we don't - // know which one we prefer during pattern matching, we match the - // non-interleaving versions for now and replace them with the - // instructions that interleave later if it makes sense. - - // Saturating narrowing casts. These may interleave later with trunc_sat. - {"halide.hexagon.pack_satub.vh", u8_sat(wild_i16x)}, - {"halide.hexagon.pack_satuh.vw", u16_sat(wild_i32x)}, - {"halide.hexagon.pack_satb.vh", i8_sat(wild_i16x)}, - {"halide.hexagon.pack_sath.vw", i16_sat(wild_i32x)}, - - // We don't have a vpack equivalent to this one, so we match it directly. - {"halide.hexagon.trunc_satuh.vuw", u16_sat(wild_u32x), Pattern::DeinterleaveOp0}, - // Narrowing casts. These may interleave later with trunclo. {"halide.hexagon.packhi.vh", u8(wild_u16x >> 8)}, {"halide.hexagon.packhi.vh", u8(wild_i16x >> 8)}, @@ -872,12 +836,6 @@ class OptimizePatterns : public IRMutator { // fall through to LLVM, which will generate large unoptimized // shuffles. static const vector> cast_rewrites = { - // Saturating narrowing - {u8_sat(wild_u32x), u8_sat(u16_sat(wild_u32x))}, - {u8_sat(wild_i32x), u8_sat(i16_sat(wild_i32x))}, - {i8_sat(wild_u32x), i8_sat(u16_sat(wild_u32x))}, - {i8_sat(wild_i32x), i8_sat(i16_sat(wild_i32x))}, - // Narrowing {u8(wild_u32x), u8(u16(wild_u32x))}, {u8(wild_i32x), u8(i16(wild_i32x))}, @@ -942,6 +900,42 @@ class OptimizePatterns : public IRMutator { } static const vector calls = { + // Saturating narrowing casts with rounding + {"halide.hexagon.trunc_satub_rnd.vh", u8_sat(rounding_shift_right(wild_i16x, 8)), Pattern::DeinterleaveOp0}, + {"halide.hexagon.trunc_satb_rnd.vh", i8_sat(rounding_shift_right(wild_i16x, 8)), Pattern::DeinterleaveOp0}, + {"halide.hexagon.trunc_satub_rnd.vuh", u8_sat(rounding_shift_right(wild_u16x, 8)), Pattern::DeinterleaveOp0}, + {"halide.hexagon.trunc_satuh_rnd.vw", u16_sat(rounding_shift_right(wild_i32x, 16)), Pattern::DeinterleaveOp0}, + {"halide.hexagon.trunc_sath_rnd.vw", i16_sat(rounding_shift_right(wild_i32x, 16)), Pattern::DeinterleaveOp0}, + {"halide.hexagon.trunc_satuh_rnd.vuw", u16_sat(rounding_shift_right(wild_u32x, 16)), Pattern::DeinterleaveOp0}, + + // Saturating narrowing casts with rounding + {"halide.hexagon.trunc_satub_shr_rnd.vh", u8_sat(rounding_shift_right(wild_i16x, wild_u16)), Pattern::DeinterleaveOp0}, + {"halide.hexagon.trunc_satb_shr_rnd.vh", i8_sat(rounding_shift_right(wild_i16x, wild_u16)), Pattern::DeinterleaveOp0}, + {"halide.hexagon.trunc_satub_shr_rnd.vuh", u8_sat(rounding_shift_right(wild_u16x, wild_u16)), Pattern::DeinterleaveOp0 | Pattern::v65orLater}, + {"halide.hexagon.trunc_satuh_shr_rnd.vw", u16_sat(rounding_shift_right(wild_i32x, wild_u32)), Pattern::DeinterleaveOp0}, + {"halide.hexagon.trunc_sath_shr_rnd.vw", i16_sat(rounding_shift_right(wild_i32x, wild_u32)), Pattern::DeinterleaveOp0}, + {"halide.hexagon.trunc_satuh_shr_rnd.vuw", u16_sat(rounding_shift_right(wild_u32x, wild_u32)), Pattern::DeinterleaveOp0}, + + // Saturating narrowing casts + {"halide.hexagon.trunc_satub_shr.vh.uh", u8_sat(wild_i16x >> wild_u16), Pattern::DeinterleaveOp0}, + {"halide.hexagon.trunc_satuh_shr.vw.uw", u16_sat(wild_i32x >> wild_u32), Pattern::DeinterleaveOp0}, + {"halide.hexagon.trunc_sath_shr.vw.uw", i16_sat(wild_i32x >> wild_u32), Pattern::DeinterleaveOp0}, + + // For some of the following narrowing casts, we have the choice of + // non-interleaving or interleaving instructions. Because we don't + // know which one we prefer during pattern matching, we match the + // non-interleaving versions for now and replace them with the + // instructions that interleave later if it makes sense. + + // Saturating narrowing casts. These may interleave later with trunc_sat. + {"halide.hexagon.pack_satub.vh", u8_sat(wild_i16x)}, + {"halide.hexagon.pack_satuh.vw", u16_sat(wild_i32x)}, + {"halide.hexagon.pack_satb.vh", i8_sat(wild_i16x)}, + {"halide.hexagon.pack_sath.vw", i16_sat(wild_i32x)}, + + // We don't have a vpack equivalent to this one, so we match it directly. + {"halide.hexagon.trunc_satuh.vuw", u16_sat(wild_u32x), Pattern::DeinterleaveOp0}, + // Multiply keep high half. {"halide.hexagon.trunc_mpy.vw.vw", mul_shift_right(wild_i32x, wild_i32x, 32)}, @@ -980,11 +974,34 @@ class OptimizePatterns : public IRMutator { {"halide.hexagon.mpy.vh.vuh", widening_mul(wild_u16x, wild_i16x), Pattern::InterleaveResult | Pattern::SwapOps01}, }; + // To hit more of the patterns we want, rewrite "double casts" + // as two stage casts. This also avoids letting vector casts + // fall through to LLVM, which will generate large unoptimized + // shuffles. + static const vector> cast_rewrites = { + // Saturating narrowing + {u8_sat(wild_u32x), u8_sat(u16_sat(wild_u32x))}, + {u8_sat(wild_i32x), u8_sat(i16_sat(wild_i32x))}, + {i8_sat(wild_u32x), i8_sat(u16_sat(wild_u32x))}, + {i8_sat(wild_i32x), i8_sat(i16_sat(wild_i32x))}, + }; + if (op->type.is_vector()) { Expr new_expr = apply_patterns(op, calls, target, this); if (!new_expr.same_as(op)) { return new_expr; } + + // If we didn't find a pattern, try using one of the + // rewrites above. + vector matches; + for (const auto &i : cast_rewrites) { + if (expr_match(i.first, op, matches)) { + Expr replacement = substitute("*", matches[0], with_lanes(i.second, op->type.lanes())); + debug(3) << "rewriting cast to: " << replacement << " from " << Expr(op) << "\n"; + return mutate(replacement); + } + } } if (op->is_intrinsic(Call::lerp)) { diff --git a/src/IR.cpp b/src/IR.cpp index d74f51d4d090..105472ab68d6 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -20,6 +20,22 @@ Expr Cast::make(Type t, Expr v) { return node; } +Expr Reinterpret::make(Type t, Expr v) { + user_assert(v.defined()) << "reinterpret of undefined Expr\n"; + int from_bits = v.type().bits() * v.type().lanes(); + int to_bits = t.bits() * t.lanes(); + user_assert(from_bits == to_bits) + << "Reinterpret cast from type " << v.type() + << " which has " << from_bits + << " bits, to type " << t + << " which has " << to_bits << " bits\n"; + + Reinterpret *node = new Reinterpret; + node->type = t; + node->value = std::move(v); + return node; +} + Expr Add::make(Expr a, Expr b) { internal_assert(a.defined()) << "Add of undefined\n"; internal_assert(b.defined()) << "Add of undefined\n"; @@ -595,12 +611,14 @@ const char *const intrinsic_op_names[] = { "bundle", "call_cached_indirect_function", "cast_mask", + "concat_bits", "count_leading_zeros", "count_trailing_zeros", "debug_to_file", "declare_box_touched", "div_round_to_zero", "dynamic_shuffle", + "extract_bits", "extract_mask_element", "get_user_context", "gpu_thread_barrier", @@ -628,18 +646,17 @@ const char *const intrinsic_op_names[] = { "promise_clamped", "random", "register_destructor", - "reinterpret", "require", "require_mask", "return_second", "rewrite_buffer", "rounding_halving_add", - "rounding_halving_sub", "rounding_mul_shift_right", "rounding_shift_left", "rounding_shift_right", "saturating_add", "saturating_sub", + "saturating_cast", "scatter_gather", "select_mask", "shift_left", @@ -971,6 +988,10 @@ void ExprNode::accept(IRVisitor *v) const { v->visit((const Cast *)this); } template<> +void ExprNode::accept(IRVisitor *v) const { + v->visit((const Reinterpret *)this); +} +template<> void ExprNode::accept(IRVisitor *v) const { v->visit((const Variable *)this); } @@ -1156,6 +1177,10 @@ Expr ExprNode::mutate_expr(IRMutator *v) const { return v->visit((const Cast *)this); } template<> +Expr ExprNode::mutate_expr(IRMutator *v) const { + return v->visit((const Reinterpret *)this); +} +template<> Expr ExprNode::mutate_expr(IRMutator *v) const { return v->visit((const Variable *)this); } diff --git a/src/IR.h b/src/IR.h index a0311f3c86a4..ff92f38d4107 100644 --- a/src/IR.h +++ b/src/IR.h @@ -34,6 +34,16 @@ struct Cast : public ExprNode { static const IRNodeType _node_type = IRNodeType::Cast; }; +/** Reinterpret value as another type, without affecting any of the bits + * (on little-endian systems). */ +struct Reinterpret : public ExprNode { + Expr value; + + static Expr make(Type t, Expr v); + + static const IRNodeType _node_type = IRNodeType::Reinterpret; +}; + /** The sum of two expressions */ struct Add : public ExprNode { Expr a, b; @@ -501,15 +511,26 @@ struct Call : public ExprNode { bitwise_or, bitwise_xor, bool_to_mask, - bundle, // Bundle multiple exprs together temporarily for analysis (e.g. CSE) + + // Bundle multiple exprs together temporarily for analysis (e.g. CSE) + bundle, call_cached_indirect_function, cast_mask, + + // Concatenate bits of the args, with least significant bits as the + // first arg (i.e. little-endian) + concat_bits, count_leading_zeros, count_trailing_zeros, debug_to_file, declare_box_touched, div_round_to_zero, dynamic_shuffle, + + // Extract some contiguous slice of bits from the argument starting at + // the nth bit, counting from the least significant bit, with the number + // of bits determined by the return type. + extract_bits, extract_mask_element, get_user_context, gpu_thread_barrier, @@ -537,25 +558,26 @@ struct Call : public ExprNode { promise_clamped, random, register_destructor, - reinterpret, require, require_mask, return_second, rewrite_buffer, rounding_halving_add, - rounding_halving_sub, rounding_mul_shift_right, rounding_shift_left, rounding_shift_right, saturating_add, saturating_sub, + saturating_cast, scatter_gather, select_mask, shift_left, shift_right, signed_integer_overflow, size_of_halide_buffer_t, - sorted_avg, // Compute (arg[0] + arg[1]) / 2, assuming arg[0] < arg[1]. + + // Compute (arg[0] + arg[1]) / 2, assuming arg[0] < arg[1]. + sorted_avg, strict_float, stringify, undef, diff --git a/src/IREquality.cpp b/src/IREquality.cpp index 9e87b6950553..20cb616d2c32 100644 --- a/src/IREquality.cpp +++ b/src/IREquality.cpp @@ -57,6 +57,7 @@ class IRComparer : public IRVisitor { void visit(const FloatImm *) override; void visit(const StringImm *) override; void visit(const Cast *) override; + void visit(const Reinterpret *) override; void visit(const Variable *) override; void visit(const Add *) override; void visit(const Sub *) override; @@ -354,6 +355,10 @@ void IRComparer::visit(const Cast *op) { compare_expr(expr.as()->value, op->value); } +void IRComparer::visit(const Reinterpret *op) { + compare_expr(expr.as()->value, op->value); +} + void IRComparer::visit(const Variable *op) { const Variable *e = expr.as(); compare_names(e->name, op->name); diff --git a/src/IRMatch.cpp b/src/IRMatch.cpp index e769ae65d038..8416d223ffc4 100644 --- a/src/IRMatch.cpp +++ b/src/IRMatch.cpp @@ -116,6 +116,16 @@ class IRMatch : public IRVisitor { } } + void visit(const Reinterpret *op) override { + const Reinterpret *e = expr.as(); + if (result && e && types_match(op->type, e->type)) { + expr = e->value; + op->value.accept(this); + } else { + result = false; + } + } + void visit(const Variable *op) override { if (!result) { return; @@ -374,6 +384,16 @@ class WithLanes : public IRMutator { } } + Expr visit(const Call *op) override { + if (op->is_intrinsic() && (op->type.lanes() != lanes)) { + auto new_args = mutate_with_changes(op->args).first; + return Call::make(with_lanes(op->type), op->name, new_args, op->call_type, + op->func, op->value_index, op->image, op->param); + } else { + return IRMutator::visit(op); + } + } + public: WithLanes(int lanes) : lanes(lanes) { @@ -432,6 +452,11 @@ bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept { // that the types of the values match, so use equal rather // than equal_helper. return equal(((const Cast &)a).value, ((const Cast &)b).value); + case IRNodeType::Reinterpret: + // While we know a and b have matching type, we don't know + // that the types of the values match, so use equal rather + // than equal_helper. + return equal(((const Reinterpret &)a).value, ((const Reinterpret &)b).value); case IRNodeType::Variable: return ((const Variable &)a).name == ((const Variable &)b).name; case IRNodeType::Add: diff --git a/src/IRMatch.h b/src/IRMatch.h index 43ce9508aa9e..0de1c21cee44 100644 --- a/src/IRMatch.h +++ b/src/IRMatch.h @@ -1358,6 +1358,10 @@ struct Intrin { struct pattern_tag {}; Call::IntrinsicOp intrin; std::tuple args; + // The type of the output of the intrinsic node. + // Only necessary in cases where it can't be inferred + // from the input types (e.g. saturating_cast). + Type optional_type_hint; static constexpr uint32_t binds = bitwise_or_reduce((bindings::mask)...); @@ -1385,7 +1389,9 @@ struct Intrin { return false; } const Call &c = (const Call &)e; - return (c.is_intrinsic(intrin) && match_args<0, bound>(0, c, state)); + return (c.is_intrinsic(intrin) && + ((optional_type_hint == Type()) || optional_type_hint == e.type) && + match_args<0, bound>(0, c, state)); } template(args).make(state, type_hint); @@ -1437,8 +1445,6 @@ struct Intrin { return halving_sub(arg0, arg1); } else if (intrin == Call::rounding_halving_add) { return rounding_halving_add(arg0, arg1); - } else if (intrin == Call::rounding_halving_sub) { - return rounding_halving_sub(arg0, arg1); } else if (intrin == Call::shift_left) { return arg0 << arg1; } else if (intrin == Call::shift_right) { @@ -1543,6 +1549,12 @@ template auto saturating_sub(A &&a, B &&b) noexcept -> Intrin { return {Call::saturating_sub, pattern_arg(a), pattern_arg(b)}; } +template +auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin { + Intrin p = {Call::saturating_cast, pattern_arg(a)}; + p.optional_type_hint = t; + return p; +} template auto halving_add(A &&a, B &&b) noexcept -> Intrin { return {Call::halving_add, pattern_arg(a), pattern_arg(b)}; @@ -1556,10 +1568,6 @@ auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin -auto rounding_halving_sub(A &&a, B &&b) noexcept -> Intrin { - return {Call::rounding_halving_sub, pattern_arg(a), pattern_arg(b)}; -} -template auto shift_left(A &&a, B &&b) noexcept -> Intrin { return {Call::shift_left, pattern_arg(a), pattern_arg(b)}; } diff --git a/src/IRMutator.cpp b/src/IRMutator.cpp index 5272f3051577..005937a17008 100644 --- a/src/IRMutator.cpp +++ b/src/IRMutator.cpp @@ -37,6 +37,14 @@ Expr IRMutator::visit(const Cast *op) { return Cast::make(op->type, std::move(value)); } +Expr IRMutator::visit(const Reinterpret *op) { + Expr value = mutate(op->value); + if (value.same_as(op->value)) { + return op; + } + return Reinterpret::make(op->type, std::move(value)); +} + namespace { template Expr mutate_binary_operator(IRMutator *mutator, const T *op) { diff --git a/src/IRMutator.h b/src/IRMutator.h index 04613a495930..c7a1984269d3 100644 --- a/src/IRMutator.h +++ b/src/IRMutator.h @@ -56,6 +56,7 @@ class IRMutator { virtual Expr visit(const FloatImm *); virtual Expr visit(const StringImm *); virtual Expr visit(const Cast *); + virtual Expr visit(const Reinterpret *); virtual Expr visit(const Variable *); virtual Expr visit(const Add *); virtual Expr visit(const Sub *); diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 697348ecd265..7d9e0bab5473 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -62,6 +62,10 @@ Expr stringify(const std::vector &args) { } Expr combine_strings(const std::vector &args) { + if (args.empty()) { + return Expr(""); + } + // Insert spaces between each expr. std::vector strings(args.size() * 2); for (size_t i = 0; i < args.size(); i++) { @@ -73,6 +77,20 @@ Expr combine_strings(const std::vector &args) { } } + // Now combine all adjacent string literals, which is + // useful to reduce emitted code size when printing + size_t i = 0; + while (i < strings.size() - 1) { + const auto *cur_str = strings[i].as(); + const auto *next_str = strings[i + 1].as(); + if (cur_str && next_str) { + strings[i] = Internal::StringImm::make(cur_str->value + next_str->value); + strings.erase(strings.begin() + i + 1); + continue; + } + i++; + } + return stringify(strings); } @@ -1231,13 +1249,6 @@ Expr halving_sub(Expr a, Expr b) { return Call::make(result_type, Call::halving_sub, {std::move(a), std::move(b)}, Call::PureIntrinsic); } -Expr rounding_halving_sub(Expr a, Expr b) { - user_assert(a.defined() && b.defined()) << "rounding_halving_sub of undefined Expr\n"; - match_types(a, b); - Type result_type = a.type(); - return Call::make(result_type, Call::rounding_halving_sub, {std::move(a), std::move(b)}, Call::PureIntrinsic); -} - Expr mul_shift_right(Expr a, Expr b, Expr q) { user_assert(a.defined() && b.defined() && q.defined()) << "mul_shift_right of undefined Expr\n"; user_assert(q.type().is_uint()) << "mul_shift_right shift must be unsigned\n"; @@ -1419,40 +1430,7 @@ Expr require(Expr condition, const std::vector &args) { } Expr saturating_cast(Type t, Expr e) { - // For float to float, guarantee infinities are always pinned to range. - if (t.is_float() && e.type().is_float()) { - if (t.bits() < e.type().bits()) { - e = cast(t, clamp(std::move(e), t.min(), t.max())); - } else { - e = clamp(cast(t, std::move(e)), t.min(), t.max()); - } - } else if (e.type() != t) { - // Limits for Int(2^n) or UInt(2^n) are not exactly representable in Float(2^n) - if (e.type().is_float() && !t.is_float() && t.bits() >= e.type().bits()) { - e = max(std::move(e), t.min()); // min values turn out to be always representable - - // This line depends on t.max() rounding upward, which should always - // be the case as it is one less than a representable value, thus - // the one larger is always the closest. - e = select(e >= cast(e.type(), t.max()), t.max(), cast(t, e)); - } else { - Expr min_bound; - if (!e.type().is_uint()) { - min_bound = lossless_cast(e.type(), t.min()); - } - Expr max_bound = lossless_cast(e.type(), t.max()); - - if (min_bound.defined() && max_bound.defined()) { - e = clamp(std::move(e), min_bound, max_bound); - } else if (min_bound.defined()) { - e = max(std::move(e), min_bound); - } else if (max_bound.defined()) { - e = min(std::move(e), max_bound); - } - e = cast(t, std::move(e)); - } - } - return e; + return Internal::Call::make(t, Internal::Call::saturating_cast, {std::move(e)}, Internal::Call::PureIntrinsic); } Expr select(Expr condition, Expr true_value, Expr false_value) { @@ -2350,15 +2328,7 @@ Expr fract(const Expr &x) { } Expr reinterpret(Type t, Expr e) { - user_assert(e.defined()) << "reinterpret of undefined Expr\n"; - int from_bits = e.type().bits() * e.type().lanes(); - int to_bits = t.bits() * t.lanes(); - user_assert(from_bits == to_bits) - << "Reinterpret cast from type " << e.type() - << " which has " << from_bits - << " bits, to type " << t - << " which has " << to_bits << " bits\n"; - return Internal::Call::make(t, Internal::Call::reinterpret, {std::move(e)}, Internal::Call::PureIntrinsic); + return Internal::Reinterpret::make(t, std::move(e)); } Expr operator&(Expr x, Expr y) { @@ -2651,4 +2621,18 @@ Expr gather(const std::vector &args) { return make_scatter_gather(args); } +Expr extract_bits(Type t, const Expr &e, const Expr &lsb) { + return Internal::Call::make(t, Internal::Call::extract_bits, {e, lsb}, Internal::Call::Intrinsic); +} + +Expr concat_bits(const std::vector &e) { + user_assert(!e.empty()) << "concat_bits requires at least one argument\n"; + user_assert((e.size() & (e.size() - 1)) == 0) << "concat_bits received " << e.size() << " arguments, which is not a power of two.\n"; + Type t = e[0].type(); + for (size_t i = 1; i < e.size(); i++) { + user_assert(e[i].type() == t) << "All arguments to concat_bits must have the same type\n"; + } + return Internal::Call::make(t.with_bits(t.bits() * (int)e.size()), Internal::Call::concat_bits, e, Internal::Call::Intrinsic); +} + } // namespace Halide diff --git a/src/IROperator.h b/src/IROperator.h index b528f41297b6..048998448c75 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -375,8 +375,6 @@ Expr halving_add(Expr a, Expr b); Expr rounding_halving_add(Expr a, Expr b); /** Compute narrow((widen(a) - widen(b)) / 2) */ Expr halving_sub(Expr a, Expr b); -/** Compute narrow((widen(a) - widen(b) + 1) / 2) */ -Expr rounding_halving_sub(Expr a, Expr b); /** Compute saturating_narrow(shift_right(widening_mul(a, b), q)) */ Expr mul_shift_right(Expr a, Expr b, Expr q); @@ -1034,21 +1032,21 @@ Expr round(Expr x); Expr trunc(Expr x); /** Returns true if the argument is a Not a Number (NaN). Requires a - * floating point argument. Vectorizes cleanly. - * Note that the Expr passed in will be evaluated in strict_float mode, - * regardless of whether strict_float mode is enabled in the current Target. */ + * floating point argument. Vectorizes cleanly. + * Note that the Expr passed in will be evaluated in strict_float mode, + * regardless of whether strict_float mode is enabled in the current Target. */ Expr is_nan(Expr x); /** Returns true if the argument is Inf or -Inf. Requires a - * floating point argument. Vectorizes cleanly. - * Note that the Expr passed in will be evaluated in strict_float mode, - * regardless of whether strict_float mode is enabled in the current Target. */ + * floating point argument. Vectorizes cleanly. + * Note that the Expr passed in will be evaluated in strict_float mode, + * regardless of whether strict_float mode is enabled in the current Target. */ Expr is_inf(Expr x); /** Returns true if the argument is a finite value (ie, neither NaN nor Inf). - * Requires a floating point argument. Vectorizes cleanly. - * Note that the Expr passed in will be evaluated in strict_float mode, - * regardless of whether strict_float mode is enabled in the current Target. */ + * Requires a floating point argument. Vectorizes cleanly. + * Note that the Expr passed in will be evaluated in strict_float mode, + * regardless of whether strict_float mode is enabled in the current Target. */ Expr is_finite(Expr x); /** Return the fractional part of a floating-point expression. If the argument @@ -1061,7 +1059,7 @@ Expr reinterpret(Type t, Expr e); template Expr reinterpret(Expr e) { - return reinterpret(type_of(), e); + return reinterpret(type_of(), std::move(e)); } /** Return the bitwise and of two expressions (which need not have the @@ -1563,6 +1561,54 @@ Expr gather(const Expr &e, Args &&...args) { } // @} +/** Extract a contiguous subsequence of the bits of 'e', starting at the bit + * index given by 'lsb', where zero is the least-significant bit, returning a + * value of type 't'. Any out-of-range bits requested are filled with zeros. + * + * extract_bits is especially useful when one wants to load a small vector of a + * wide type, and treat it as a larger vector of a smaller type. For example, + * loading a vector of 32 uint8 values from a uint32 Func can be done as + * follows: +\code +f8(x) = extract_bits(f32(x/4), 8*(x%4)); +f8.align_bounds(x, 4).vectorize(x, 32); +\endcode + * Note that the align_bounds call is critical so that the narrow Exprs are + * aligned to the wider Exprs. This makes the x%4 term collapse to a + * constant. If f8 is an output Func, then constraining the min value of x to be + * a known multiple of four would also be sufficient, e.g. via: +\code +f8.output_buffer().dim(0).set_min(0); +\endcode + * + * See test/correctness/extract_concat_bits.cpp for a complete example. */ +// @{ +Expr extract_bits(Type t, const Expr &e, const Expr &lsb); + +template +Expr extract_bits(const Expr &e, const Expr &lsb) { + return extract_bits(type_of(), e, lsb); +} +// @} + +/** Given a number of Exprs of the same type, concatenate their bits producing a + * single Expr of the same type code of the input but with more bits. The + * number of arguments must be a power of two. + * + * concat_bits is especially useful when one wants to treat a Func containing + * values of a narrow type as a Func containing fewer values of a wider + * type. For example, the following code reinterprets vectors of 32 uint8 values + * as a vector of 8 uint32s: + * +\code +f32(x) = concat_bits({f8(4*x), f8(4*x + 1), f8(4*x + 2), f8(4*x + 3)}); +f32.vectorize(x, 8); +\endcode + * + * See test/correctness/extract_concat_bits.cpp for a complete example. + */ +Expr concat_bits(const std::vector &e); + } // namespace Halide #endif diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index dec52fc28f7f..38f57e46649e 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -519,6 +519,12 @@ void IRPrinter::visit(const Cast *op) { stream << ")"; } +void IRPrinter::visit(const Reinterpret *op) { + stream << "reinterpret<" << op->type << ">("; + print(op->value); + stream << ")"; +} + void IRPrinter::visit(const Variable *op) { if (!known_type.contains(op->name) && (op->type != Int(32))) { diff --git a/src/IRPrinter.h b/src/IRPrinter.h index e0cb4cab5968..666235988cd7 100644 --- a/src/IRPrinter.h +++ b/src/IRPrinter.h @@ -155,6 +155,7 @@ class IRPrinter : public IRVisitor { void visit(const FloatImm *) override; void visit(const StringImm *) override; void visit(const Cast *) override; + void visit(const Reinterpret *) override; void visit(const Variable *) override; void visit(const Add *) override; void visit(const Sub *) override; diff --git a/src/IRVisitor.cpp b/src/IRVisitor.cpp index bde0799bdcee..7f9993987200 100644 --- a/src/IRVisitor.cpp +++ b/src/IRVisitor.cpp @@ -22,6 +22,10 @@ void IRVisitor::visit(const Cast *op) { op->value.accept(this); } +void IRVisitor::visit(const Reinterpret *op) { + op->value.accept(this); +} + void IRVisitor::visit(const Variable *) { } @@ -293,6 +297,10 @@ void IRGraphVisitor::visit(const Cast *op) { include(op->value); } +void IRGraphVisitor::visit(const Reinterpret *op) { + include(op->value); +} + void IRGraphVisitor::visit(const Variable *op) { } diff --git a/src/IRVisitor.h b/src/IRVisitor.h index f29bedc182bc..4e1650ff22be 100644 --- a/src/IRVisitor.h +++ b/src/IRVisitor.h @@ -34,6 +34,7 @@ class IRVisitor { virtual void visit(const FloatImm *); virtual void visit(const StringImm *); virtual void visit(const Cast *); + virtual void visit(const Reinterpret *); virtual void visit(const Variable *); virtual void visit(const Add *); virtual void visit(const Sub *); @@ -104,6 +105,7 @@ class IRGraphVisitor : public IRVisitor { void visit(const FloatImm *) override; void visit(const StringImm *) override; void visit(const Cast *) override; + void visit(const Reinterpret *) override; void visit(const Variable *) override; void visit(const Add *) override; void visit(const Sub *) override; @@ -174,6 +176,8 @@ class VariadicVisitor { return ((T *)this)->visit((const Broadcast *)node, std::forward(args)...); case IRNodeType::Cast: return ((T *)this)->visit((const Cast *)node, std::forward(args)...); + case IRNodeType::Reinterpret: + return ((T *)this)->visit((const Reinterpret *)node, std::forward(args)...); case IRNodeType::Variable: return ((T *)this)->visit((const Variable *)node, std::forward(args)...); case IRNodeType::Add: @@ -258,6 +262,7 @@ class VariadicVisitor { case IRNodeType::StringImm: case IRNodeType::Broadcast: case IRNodeType::Cast: + case IRNodeType::Reinterpret: case IRNodeType::Variable: case IRNodeType::Add: case IRNodeType::Sub: diff --git a/src/ImageParam.cpp b/src/ImageParam.cpp index cda49b501c4c..ca1b5d3922ac 100644 --- a/src/ImageParam.cpp +++ b/src/ImageParam.cpp @@ -33,7 +33,7 @@ Func ImageParam::create_func() const { // Discourage future Funcs from having the same name Internal::unique_name(name()); } - Func f(name() + "_im"); + Func f(param.type(), param.dimensions(), name() + "_im"); f(args) = Internal::Call::make(param, args_expr); return f; } diff --git a/src/InferArguments.cpp b/src/InferArguments.cpp index a9d1cde11e13..d2f55b1fa781 100644 --- a/src/InferArguments.cpp +++ b/src/InferArguments.cpp @@ -182,6 +182,11 @@ class InferArguments : public IRGraphVisitor { } } } + + // It also misses wrappers + for (const auto &p : func.wrappers()) { + Function(p.second).accept(this); + } } void include_parameter(const Parameter &p) { diff --git a/src/JITModule.cpp b/src/JITModule.cpp index acb8be5da8c7..e595613ffb5e 100644 --- a/src/JITModule.cpp +++ b/src/JITModule.cpp @@ -21,6 +21,7 @@ #include "LLVM_Output.h" #include "LLVM_Runtime_Linker.h" #include "Pipeline.h" +#include "WasmExecutor.h" namespace Halide { namespace Internal { @@ -253,10 +254,10 @@ void JITModule::compile_module(std::unique_ptr m, const string &fu debug(2) << "Target triple: " << m->getTargetTriple() << "\n"; string error_string; - string mcpu; - string mattrs; + llvm::for_each(*m, set_function_attributes_from_halide_target_options); + llvm::TargetOptions options; - get_target_options(*m, options, mcpu, mattrs); + get_target_options(*m, options); DataLayout initial_module_data_layout = m->getDataLayout(); string module_name = m->getModuleIdentifier(); @@ -269,11 +270,6 @@ void JITModule::compile_module(std::unique_ptr m, const string &fu engine_builder.setMCJITMemoryManager(std::unique_ptr(memory_manager)); engine_builder.setOptLevel(CodeGenOpt::Aggressive); - if (!mcpu.empty()) { - engine_builder.setMCPU(mcpu); - } - std::vector mattrs_array = {mattrs}; - engine_builder.setMAttrs(mattrs_array); TargetMachine *tm = engine_builder.selectTarget(); internal_assert(tm) << error_string << "\n"; @@ -399,8 +395,8 @@ JITModule::Symbol JITModule::entrypoint_symbol() const { return jit_module->entrypoint; } -int (*JITModule::argv_function() const)(const void **) { - return (int (*)(const void **))jit_module->argv_entrypoint.address; +int (*JITModule::argv_function() const)(const void *const *) { + return (int (*)(const void *const *))jit_module->argv_entrypoint.address; } JITModule::Symbol JITModule::argv_entrypoint_symbol() const { @@ -990,5 +986,139 @@ void JITSharedRuntime::reuse_device_allocations(bool b) { shared_runtimes(MainShared).reuse_device_allocations(b); } +JITCache::JITCache(Target jit_target, + std::vector arguments, + std::map jit_externs, + JITModule jit_module, + WasmModule wasm_module) + : jit_target(jit_target), // clang-tidy complains that this is "trivially copyable" and std::move shouldn't be here, grr + arguments(std::move(arguments)), + jit_externs(std::move(jit_externs)), + jit_module(std::move(jit_module)), + wasm_module(std::move(wasm_module)) { +} + +Target JITCache::get_compiled_jit_target() const { + // This essentially is just a getter for contents->jit_target, + // but also reality-checks that the status of the jit_module and/or wasm_module + // match what we expect. + const bool has_wasm = wasm_module.contents.defined(); + const bool has_native = jit_module.compiled(); + if (jit_target.arch == Target::WebAssembly) { + internal_assert(has_wasm && !has_native); + } else if (!jit_target.has_unknowns()) { + internal_assert(!has_wasm && has_native); + } else { + internal_assert(!has_wasm && !has_native); + } + return jit_target; +} + +int JITCache::call_jit_code(const Target &target, const void *const *args) { +#if defined(__has_feature) +#if __has_feature(memory_sanitizer) + user_warning << "MSAN does not support JIT compilers of any sort, and will report " + "false positives when used in conjunction with the Halide JIT. " + "If you need to test with MSAN enabled, you must use ahead-of-time " + "compilation for Halide code."; +#endif +#endif + if (target.arch == Target::WebAssembly) { + internal_assert(wasm_module.contents.defined()); + return wasm_module.run(args); + } else { + auto argv_wrapper = jit_module.argv_function(); + internal_assert(argv_wrapper != nullptr); + return argv_wrapper(args); + } +} + +void JITCache::finish_profiling(JITUserContext *context) { + // If we're profiling, report runtimes and reset profiler stats. + if (jit_target.has_feature(Target::Profile) || jit_target.has_feature(Target::ProfileByTimer)) { + JITModule::Symbol report_sym = jit_module.find_symbol_by_name("halide_profiler_report"); + JITModule::Symbol reset_sym = jit_module.find_symbol_by_name("halide_profiler_reset"); + if (report_sym.address && reset_sym.address) { + void (*report_fn_ptr)(JITUserContext *) = (void (*)(JITUserContext *))(report_sym.address); + report_fn_ptr(context); + + void (*reset_fn_ptr)() = (void (*)())(reset_sym.address); + reset_fn_ptr(); + } + } +} + +void JITErrorBuffer::concat(const char *message) { + size_t len = strlen(message); + + if (len && message[len - 1] != '\n') { + // Claim some extra space for a newline. + len++; + } + + // Atomically claim some space in the buffer + size_t old_end = end.fetch_add(len); + + if (old_end + len >= MaxBufSize - 1) { + // Out of space + return; + } + + for (size_t i = 0; i < len - 1; i++) { + buf[old_end + i] = message[i]; + } + if (buf[old_end + len - 2] != '\n') { + buf[old_end + len - 1] = '\n'; + } +} + +std::string JITErrorBuffer::str() const { + return std::string(buf, end); +} + +/*static*/ void JITErrorBuffer::handler(JITUserContext *ctx, const char *message) { + if (ctx && ctx->error_buffer) { + ctx->error_buffer->concat(message); + } +} + +JITFuncCallContext::JITFuncCallContext(JITUserContext *context, const JITHandlers &pipeline_handlers) + : context(context) { + custom_error_handler = (context->handlers.custom_error != nullptr || + pipeline_handlers.custom_error != nullptr); + // Hook the error handler if not set + if (!custom_error_handler) { + context->handlers.custom_error = JITErrorBuffer::handler; + } + + // Add the handlers stored in the pipeline for anything else + // not set, then for anything still not set, use the global + // active handlers. + JITSharedRuntime::populate_jit_handlers(context, pipeline_handlers); + context->error_buffer = &error_buffer; + + debug(2) << "custom_print: " << (void *)context->handlers.custom_print << "\n" + << "custom_malloc: " << (void *)context->handlers.custom_malloc << "\n" + << "custom_free: " << (void *)context->handlers.custom_free << "\n" + << "custom_do_task: " << (void *)context->handlers.custom_do_task << "\n" + << "custom_do_par_for: " << (void *)context->handlers.custom_do_par_for << "\n" + << "custom_error: " << (void *)context->handlers.custom_error << "\n" + << "custom_trace: " << (void *)context->handlers.custom_trace << "\n"; +} + +void JITFuncCallContext::finalize(int exit_status) { + // Only report the errors if no custom error handler was installed + if (exit_status && !custom_error_handler) { + std::string output = error_buffer.str(); + if (output.empty()) { + output = ("The pipeline returned exit status " + + std::to_string(exit_status) + + " but halide_error was never called.\n"); + } + halide_runtime_error << output; + error_buffer.end = 0; + } +} + } // namespace Internal } // namespace Halide diff --git a/src/JITModule.h b/src/JITModule.h index ee27fe99216b..467fb82db207 100644 --- a/src/JITModule.h +++ b/src/JITModule.h @@ -8,9 +8,12 @@ #include #include +#include #include "IntrusivePtr.h" +#include "Target.h" #include "Type.h" +#include "WasmExecutor.h" #include "runtime/HalideRuntime.h" namespace llvm { @@ -21,7 +24,6 @@ namespace Halide { struct ExternCFunction; struct JITExtern; -struct Target; class Module; struct JITUserContext; @@ -206,7 +208,7 @@ struct JITModule { * be nullptr for a JITModule which has not yet been compiled or one * that is not a Halide Func compilation at all. */ // @{ - typedef int (*argv_wrapper)(const void **args); + typedef int (*argv_wrapper)(const void *const *args); argv_wrapper argv_function() const; // @} @@ -281,6 +283,50 @@ class JITSharedRuntime { void *get_symbol_address(const char *s); +struct JITCache { + Target jit_target; + // Arguments for all inputs and outputs + std::vector arguments; + std::map jit_externs; + JITModule jit_module; + WasmModule wasm_module; + + JITCache() = default; + JITCache(Target jit_target, + std::vector arguments, + std::map jit_externs, + JITModule jit_module, + WasmModule wasm_module); + + Target get_compiled_jit_target() const; + + int call_jit_code(const Target &target, const void *const *args); + + void finish_profiling(JITUserContext *context); +}; + +struct JITErrorBuffer { + enum { MaxBufSize = 4096 }; + char buf[MaxBufSize]; + std::atomic end{0}; + + void concat(const char *message); + + std::string str() const; + + static void handler(JITUserContext *ctx, const char *message); +}; + +struct JITFuncCallContext { + JITErrorBuffer error_buffer; + JITUserContext *context; + bool custom_error_handler; + + JITFuncCallContext(JITUserContext *context, const JITHandlers &pipeline_handlers); + + void finalize(int exit_status); +}; + } // namespace Internal } // namespace Halide diff --git a/src/LICM.cpp b/src/LICM.cpp index 5cfbdfedb0bc..386a05bd1808 100644 --- a/src/LICM.cpp +++ b/src/LICM.cpp @@ -89,6 +89,10 @@ class LiftLoopInvariants : public IRMutator { return false; } } + if (const Reinterpret *reinterpret = e.as()) { + // Don't lift Reinterpret nodes. They're free. + return should_lift(reinterpret->value); + } if (const Add *add = e.as()) { if (add->type == Int(32) && is_const(add->b)) { @@ -97,8 +101,7 @@ class LiftLoopInvariants : public IRMutator { } } if (const Call *call = e.as()) { - if (Call::as_tag(call) || - call->is_intrinsic(Call::reinterpret)) { + if (Call::as_tag(call)) { // Don't lift these intrinsics. They're free. return should_lift(call->args[0]); } @@ -209,6 +212,8 @@ class LICM : public IRMutator { int cost(const Expr &e, const set &vars) { if (is_const(e)) { return 0; + } else if (const Reinterpret *reinterpret = e.as()) { + return cost(reinterpret->value, vars); } else if (const Variable *var = e.as()) { if (vars.count(var->name)) { // We're loading this already @@ -223,13 +228,6 @@ class LICM : public IRMutator { return cost(sub->a, vars) + cost(sub->b, vars) + 1; } else if (const Mul *mul = e.as()) { return cost(mul->a, vars) + cost(mul->b, vars) + 1; - } else if (const Call *call = e.as()) { - if (call->is_intrinsic(Call::reinterpret)) { - internal_assert(call->args.size() == 1); - return cost(call->args[0], vars); - } else { - return 100; - } } else { return 100; } diff --git a/src/LLVM_Headers.h b/src/LLVM_Headers.h index f23362a2d79f..42f44e0428c4 100644 --- a/src/LLVM_Headers.h +++ b/src/LLVM_Headers.h @@ -1,10 +1,10 @@ #ifndef HALIDE_LLVM_HEADERS_H #define HALIDE_LLVM_HEADERS_H -#if LLVM_VERSION >= 120 +#if LLVM_VERSION >= 130 // We're good to go #else -#error "Compiling Halide requires LLVM 12.0 or newer" +#error "Compiling Halide requires LLVM 13.0 or newer" #endif // No msvc warnings from llvm headers please diff --git a/src/LLVM_Runtime_Linker.cpp b/src/LLVM_Runtime_Linker.cpp index b73e88313132..4995b197b433 100644 --- a/src/LLVM_Runtime_Linker.cpp +++ b/src/LLVM_Runtime_Linker.cpp @@ -91,6 +91,7 @@ DECLARE_CPP_INITMOD(errors) DECLARE_CPP_INITMOD(fake_get_symbol) DECLARE_CPP_INITMOD(fake_thread_pool) DECLARE_CPP_INITMOD(float16_t) +DECLARE_CPP_INITMOD(force_include_types) DECLARE_CPP_INITMOD(fuchsia_clock) DECLARE_CPP_INITMOD(fuchsia_host_cpu_count) DECLARE_CPP_INITMOD(fuchsia_yield) @@ -101,8 +102,6 @@ DECLARE_CPP_INITMOD(ios_io) DECLARE_CPP_INITMOD(linux_clock) DECLARE_CPP_INITMOD(linux_host_cpu_count) DECLARE_CPP_INITMOD(linux_yield) -DECLARE_CPP_INITMOD(matlab) -DECLARE_CPP_INITMOD(metadata) DECLARE_CPP_INITMOD(module_aot_ref_count) DECLARE_CPP_INITMOD(module_jit_ref_count) DECLARE_CPP_INITMOD(msan) @@ -280,10 +279,10 @@ DECLARE_NO_INITMOD(wasm_math) #endif // WITH_WEBASSEMBLY #ifdef WITH_RISCV -//DECLARE_LL_INITMOD(riscv) +// DECLARE_LL_INITMOD(riscv) DECLARE_CPP_INITMOD(riscv_cpu_features) #else -//DECLARE_NO_INITMOD(riscv) +// DECLARE_NO_INITMOD(riscv) DECLARE_NO_INITMOD(riscv_cpu_features) #endif // WITH_RISCV @@ -783,7 +782,7 @@ std::unique_ptr link_with_wasm_jit_runtime(llvm::LLVMContext *c, c modules.push_back(get_initmod_to_string(c, bits_64, debug)); modules.push_back(get_initmod_alignment_32(c, bits_64, debug)); modules.push_back(get_initmod_device_interface(c, bits_64, debug)); - modules.push_back(get_initmod_metadata(c, bits_64, debug)); + modules.push_back(get_initmod_force_include_types(c, bits_64, debug)); modules.push_back(get_initmod_float16_t(c, bits_64, debug)); modules.push_back(get_initmod_errors(c, bits_64, debug)); modules.push_back(get_initmod_msan_stubs(c, bits_64, debug)); @@ -1014,7 +1013,6 @@ std::unique_ptr get_initial_module_for_target(Target t, llvm::LLVM modules.push_back(get_initmod_allocation_cache(c, bits_64, debug)); modules.push_back(get_initmod_device_interface(c, bits_64, debug)); - modules.push_back(get_initmod_metadata(c, bits_64, debug)); modules.push_back(get_initmod_float16_t(c, bits_64, debug)); modules.push_back(get_initmod_errors(c, bits_64, debug)); @@ -1210,16 +1208,14 @@ std::unique_ptr get_initial_module_for_target(Target t, llvm::LLVM } } - if (module_type == ModuleAOT && t.has_feature(Target::Matlab)) { - modules.push_back(get_initmod_matlab(c, bits_64, debug)); - } - if (module_type == ModuleAOTNoRuntime || module_type == ModuleJITInlined || t.os == Target::NoOS) { modules.push_back(get_initmod_runtime_api(c, bits_64, debug)); } + modules.push_back(get_initmod_force_include_types(c, bits_64, debug)); + link_modules(modules, t); if (t.os == Target::Windows && diff --git a/src/Lambda.h b/src/Lambda.h index 55f883d96422..cde63b2efb68 100644 --- a/src/Lambda.h +++ b/src/Lambda.h @@ -47,4 +47,4 @@ Func lambda(const Var &x, const Var &y, const Var &z, const Var &w, const Var &v } // namespace Halide -#endif //HALIDE_LAMBDA_H +#endif // HALIDE_LAMBDA_H diff --git a/src/Lower.cpp b/src/Lower.cpp index 36b8a257e525..38ad867686e6 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -571,7 +571,7 @@ Module lower(const vector &output_funcs, const vector &requirements, bool trace_pipeline, const vector &custom_passes) { - Module result_module{extract_namespaces(pipeline_name), t}; + Module result_module{strip_namespaces(pipeline_name), t}; run_with_large_stack([&]() { lower_impl(output_funcs, pipeline_name, t, args, linkage_type, requirements, trace_pipeline, custom_passes, result_module); }); diff --git a/src/LowerParallelTasks.cpp b/src/LowerParallelTasks.cpp index 73d0dd4c8504..cab6d6c20590 100644 --- a/src/LowerParallelTasks.cpp +++ b/src/LowerParallelTasks.cpp @@ -240,6 +240,8 @@ struct LowerParallelTasks : public IRMutator { const std::string closure_arg_name = unique_name("closure_arg"); auto closure_arg = make_scalar_arg(closure_arg_name); + Type closure_function_type; + std::vector closure_args(use_parallel_for ? 3 : 5); closure_args[0] = make_scalar_arg("__user_context"); if (use_parallel_for) { @@ -247,6 +249,8 @@ struct LowerParallelTasks : public IRMutator { // // typedef int (*halide_task_t)(void *user_context, int task_number, uint8_t *closure); // + closure_function_type = type_of(); + closure_args[1] = make_scalar_arg(t.loop_var); closure_args[2] = closure_arg; // closure_task_parent remains undefined here. @@ -255,6 +259,8 @@ struct LowerParallelTasks : public IRMutator { // // typedef int (*halide_loop_task_t)(void *user_context, int min, int extent, uint8_t *closure, void *task_parent); // + closure_function_type = type_of(); + const std::string closure_task_parent_name = unique_name("__task_parent"); closure_task_parent = Variable::make(type_of(), closure_task_parent_name); // We peeled off a loop. Wrap a new loop around the body @@ -292,7 +298,7 @@ struct LowerParallelTasks : public IRMutator { // TODO(zvookin): Figure out how we want to handle name mangling of closures. // For now, the C++ backend makes them extern "C" so they have to be NameMangling::C. - LoweredFunc closure_func{new_function_name, closure_args, std::move(wrapped_body), LinkageType::External, NameMangling::C}; + LoweredFunc closure_func{new_function_name, closure_args, std::move(wrapped_body), LinkageType::Internal, NameMangling::C}; if (target.has_feature(Target::Debug)) { debug_arguments(&closure_func, target); } @@ -305,7 +311,7 @@ struct LowerParallelTasks : public IRMutator { // case some joker names an intermediate Func or Var the same // name as the pipeline. This prefix works transparently in the // C++ backend. - Expr new_function_name_arg = Variable::make(Handle(), "::" + new_function_name); + Expr new_function_name_arg = Variable::make(closure_function_type, "::" + new_function_name); Expr closure_struct_arg = Cast::make(type_of(), closure_struct); if (use_parallel_for) { @@ -338,7 +344,7 @@ struct LowerParallelTasks : public IRMutator { if (!tasks_array_args.empty()) { // Allocate task list array - Expr tasks_list = Call::make(Handle(), Call::make_struct, tasks_array_args, Call::PureIntrinsic); + Expr tasks_list = Call::make(type_of(), Call::make_struct, tasks_array_args, Call::PureIntrinsic); Expr user_context = Call::make(type_of(), Call::get_user_context, {}, Call::PureIntrinsic); Expr task_parent = has_task_parent ? task_parents.top() : make_zero(Handle()); result = Call::make(Int(32), "halide_do_parallel_tasks", diff --git a/src/MatlabWrapper.cpp b/src/MatlabWrapper.cpp deleted file mode 100644 index f6bd73974c5d..000000000000 --- a/src/MatlabWrapper.cpp +++ /dev/null @@ -1,72 +0,0 @@ -#include "Error.h" -#include "LLVM_Headers.h" - -using namespace llvm; - -namespace Halide { -namespace Internal { - -// Define the mex wrapper API call (mexFunction) for a func with name pipeline_name. -llvm::Function *define_matlab_wrapper(llvm::Module *module, - llvm::Function *pipeline_argv_wrapper, - llvm::Function *metadata_getter) { - user_assert(!module->getFunction("mexFunction")) - << "Module already contains a mexFunction. Only one pipeline can define a mexFunction.\n"; - - LLVMContext &ctx = module->getContext(); - - llvm::Function *call_pipeline = module->getFunction("halide_matlab_call_pipeline"); - internal_assert(call_pipeline) << "Did not find function 'halide_matlab_call_pipeline' in module.\n"; - - llvm::Type *void_ty = llvm::Type::getVoidTy(ctx); - llvm::Type *i8_ty = llvm::Type::getInt8Ty(ctx); - llvm::Type *i32_ty = llvm::Type::getInt32Ty(ctx); - Value *user_context = ConstantPointerNull::get(i8_ty->getPointerTo()); - - llvm::StructType *mxArray_ty = get_llvm_struct_type_by_name(module, "struct.mxArray"); - internal_assert(mxArray_ty) << "Did not find mxArray in initial module"; - llvm::Type *mxArray_ptr_ty = mxArray_ty->getPointerTo(); - llvm::Type *mxArray_ptr_ptr_ty = mxArray_ptr_ty->getPointerTo(); - - // Create the mexFunction function. - // (http://www.mathworks.com/help/matlab/apiref/mexfunction.html) - llvm::Type *mex_arg_types[] = { - i32_ty, - mxArray_ptr_ptr_ty, - i32_ty, - mxArray_ptr_ptr_ty, - }; - FunctionType *mex_ty = FunctionType::get(void_ty, mex_arg_types, false); - llvm::Function *mex = llvm::Function::Create(mex_ty, llvm::GlobalValue::ExternalLinkage, "mexFunction", module); - BasicBlock *entry = BasicBlock::Create(ctx, "entry", mex); - - IRBuilder<> ir(ctx); - ir.SetInsertPoint(entry); - - // Call the metadata_getter function to get the metadata pointer block. - llvm::CallInst *metadata_ptr = ir.CreateCall(metadata_getter); - - // Extract the argument values from the mexFunction. - llvm::Function::arg_iterator mex_args = mex->arg_begin(); - Value *nlhs = iterator_to_pointer(mex_args++); - Value *plhs = iterator_to_pointer(mex_args++); - Value *nrhs = iterator_to_pointer(mex_args++); - Value *prhs = iterator_to_pointer(mex_args++); - - Value *call_pipeline_args[] = { - user_context, - pipeline_argv_wrapper, - metadata_ptr, - nlhs, - plhs, - nrhs, - prhs, - }; - ir.CreateCall(call_pipeline, call_pipeline_args); - ir.CreateRetVoid(); - - return mex; -} - -} // namespace Internal -} // namespace Halide diff --git a/src/MatlabWrapper.h b/src/MatlabWrapper.h deleted file mode 100644 index 24e796dc87d4..000000000000 --- a/src/MatlabWrapper.h +++ /dev/null @@ -1,28 +0,0 @@ -#ifndef HALIDE_MATLAB_OUTPUT_H -#define HALIDE_MATLAB_OUTPUT_H - -/** \file - * - * Provides an output function to generate a Matlab mex API compatible object file. - */ - -namespace llvm { -class Module; -class Function; -class Value; -} // namespace llvm - -namespace Halide { -namespace Internal { - -/** Add a mexFunction wrapper definition to the module, calling the - * function with the name pipeline_name. Returns the mexFunction - * definition. */ -llvm::Function *define_matlab_wrapper(llvm::Module *module, - llvm::Function *pipeline_argv_wrapper, - llvm::Function *metadata_getter); - -} // namespace Internal -} // namespace Halide - -#endif diff --git a/src/Module.cpp b/src/Module.cpp index a8a65f202fd2..e958fabcd804 100644 --- a/src/Module.cpp +++ b/src/Module.cpp @@ -20,8 +20,6 @@ #include "PythonExtensionGen.h" #include "StmtToHtml.h" -using Halide::Internal::debug; - namespace Halide { namespace Internal { @@ -29,14 +27,6 @@ namespace Internal { // and the appropriate file extension for each output type. If you are // explicitly managing file extensions somewhere else, you are probably // doing it wrong; please prefer to use this table as the source of truth. -// -// Note that we deliberately default to ".py.cpp" (rather than .py.c) here for python_extension; -// in theory, the Python extension file we generate can be compiled just -// fine as a plain-C file... but if we are building with cpp-name-mangling -// enabled in the target, we will include generated .h files that can't be compiled. -// We really don't want to vary the file extensions based on target flags, -// and in practice, it's extremely unlikely that anyone needs to rely on this -// being pure C output (vs possibly C++). std::map get_output_info(const Target &target) { constexpr bool IsMulti = true; constexpr bool IsSingle = false; @@ -255,7 +245,11 @@ std::string indent_string(const std::string &src, const std::string &indent) { void emit_schedule_file(const std::string &name, const std::vector &targets, const std::string &scheduler_name, +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API const std::string &machine_params_string, +#else + const std::string &autoscheduler_params_string, +#endif const std::string &body, std::ostream &stream) { std::string s = R"INLINE_CODE(#ifndef $CLEANNAME$_SCHEDULE_H @@ -264,7 +258,7 @@ void emit_schedule_file(const std::string &name, // MACHINE GENERATED -- DO NOT EDIT // This schedule was automatically generated by $SCHEDULER$ // for target=$TARGET$ // NOLINT -// with machine_params=$MACHINEPARAMS$ +// with $MPNAME$=$MACHINEPARAMS$ #include "Halide.h" @@ -318,7 +312,13 @@ inline void apply_schedule_$SHORTNAME$( s = replace_all(s, "$NAMESPACECLOSE$", nsclose); s = replace_all(s, "$TARGET$", target_string); s = replace_all(s, "$BODY$", body_text); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + s = replace_all(s, "$MPNAME$", "machine_params"); s = replace_all(s, "$MACHINEPARAMS$", machine_params_string); +#else + s = replace_all(s, "$MPNAME$", "autoscheduler_params"); + s = replace_all(s, "$MACHINEPARAMS$", autoscheduler_params_string); +#endif stream << s; } @@ -399,8 +399,10 @@ struct ModuleContents { std::vector> buffers; std::vector functions; std::vector submodules; +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE std::vector external_code; - std::map metadata_name_map; +#endif + MetadataNameMap metadata_name_map; bool any_strict_float{false}; std::unique_ptr auto_scheduler_results; }; @@ -485,9 +487,11 @@ const std::vector &Module::submodules() const { return contents->submodules; } +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE const std::vector &Module::external_code() const { return contents->external_code; } +#endif Internal::LoweredFunc Module::get_function_by_name(const std::string &name) const { for (const auto &f : functions()) { @@ -511,9 +515,11 @@ void Module::append(const Module &module) { contents->submodules.push_back(module); } +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE void Module::append(const ExternalCode &external_code) { contents->external_code.push_back(external_code); } +#endif Module link_modules(const std::string &name, const std::vector &modules) { Module output(name, modules.front().target()); @@ -581,12 +587,15 @@ Module Module::resolve_submodules() const { for (const auto &buf : buffers()) { lowered_module.append(buf); } +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE for (const auto &ec : external_code()) { lowered_module.append(ec); } +#endif for (const auto &m : submodules()) { Module copy(m.resolve_submodules()); +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE // Propagate external code blocks. for (const auto &ec : external_code()) { // TODO(zalman): Is this the right thing to do? @@ -601,6 +610,7 @@ Module Module::resolve_submodules() const { copy.append(ec); } } +#endif auto buf = copy.compile_to_buffer(); lowered_module.append(buf); @@ -618,7 +628,7 @@ void Module::remap_metadata_name(const std::string &from, const std::string &to) contents->metadata_name_map[from] = to; } -std::map Module::get_metadata_name_map() const { +MetadataNameMap Module::get_metadata_name_map() const { return contents->metadata_name_map; } @@ -685,7 +695,7 @@ void Module::compile(const std::map &output_files) } } debug(1) << "Module.compile(): static_library " << output_files.at(OutputFileType::static_library) << "\n"; - Target base_target(target().os, target().arch, target().bits); + Target base_target(target().os, target().arch, target().bits, target().processor_tune); create_static_library(temp_dir.files(), base_target, output_files.at(OutputFileType::static_library)); } if (contains(output_files, OutputFileType::assembly)) { @@ -731,17 +741,28 @@ void Module::compile(const std::map &output_files) debug(1) << "Module.compile(): schedule " << output_files.at(OutputFileType::schedule) << "\n"; std::ofstream file(output_files.at(OutputFileType::schedule)); auto *r = contents->auto_scheduler_results.get(); + std::string body = r && !r->schedule_source.empty() ? r->schedule_source : "// No autoscheduler has been run for this Generator.\n"; +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API std::string scheduler = r ? r->scheduler_name : "(None)"; std::string machine_params = r ? r->machine_params_string : "(None)"; - std::string body = r && !r->schedule_source.empty() ? r->schedule_source : "// No autoscheduler has been run for this Generator.\n"; emit_schedule_file(name(), {target()}, scheduler, machine_params, body, file); +#else + std::string scheduler = r ? r->autoscheduler_params.name : "(None)"; + std::string autoscheduler_params_string = r ? r->autoscheduler_params.to_string() : "(None)"; + emit_schedule_file(name(), {target()}, scheduler, autoscheduler_params_string, body, file); +#endif } if (contains(output_files, OutputFileType::python_schedule)) { debug(1) << "Module.compile(): python_schedule " << output_files.at(OutputFileType::python_schedule) << "\n"; std::ofstream file(output_files.at(OutputFileType::python_schedule)); auto *r = contents->auto_scheduler_results.get(); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API std::string scheduler = r ? r->scheduler_name : "(None)"; std::string machine_params = r ? r->machine_params_string : "(None)"; +#else + std::string scheduler = r ? r->autoscheduler_params.name : "(None)"; + std::string machine_params = "(None)"; +#endif std::string body = r && !r->python_schedule_source.empty() ? r->python_schedule_source : "# No autoscheduler has been run for this Generator.\n"; emit_python_schedule_file(name(), {target()}, scheduler, machine_params, body, file); } @@ -754,7 +775,7 @@ void Module::compile(const std::map &output_files) binfile.write((const char *)r->featurization.data(), r->featurization.size()); } binfile.close(); - std::ofstream featurization_index(output_files.at(Output::featurization) + ".index", std::ios::binary | std::ios_base::trunc); + std::ofstream featurization_index(output_files.at(OutputFileType::featurization) + ".index", std::ios::binary | std::ios_base::trunc); if (r) { featurization_index.write((const char *)r->featurization_index.data(), r->featurization_index.size()); } @@ -855,6 +876,11 @@ void compile_multitarget(const std::string &fn_name, user_assert(suffixes.empty() || suffixes.size() == targets.size()) << "The suffixes list must be empty or the same length as the targets list.\n"; + // Some tests were mistakenly passing filenames/pathnames here, which is not kosher + for (char c : "/\\") { + user_assert(fn_name.find(c) == std::string::npos) << "compile_multitarget: fn_name must not contain '" << c << "', but saw '" << fn_name << "'\n"; + } + // The final target in the list is considered "baseline", and is used // for (e.g.) the runtime and shared code. It is often just arch-bits-os // with no other features (though this is *not* a requirement). @@ -937,7 +963,6 @@ void compile_multitarget(const std::string &fn_name, Target::CPlusPlusMangling, Target::Debug, Target::JIT, - Target::Matlab, Target::MSAN, Target::NoRuntime, Target::TSAN, @@ -956,11 +981,7 @@ void compile_multitarget(const std::string &fn_name, std::string sub_fn_name = needs_wrapper ? (fn_name + suffix) : fn_name; // We always produce the runtime separately, so add NoRuntime explicitly. - // Matlab should be added to the wrapper pipeline below, instead of each sub-pipeline. Target sub_fn_target = target.with_feature(Target::NoRuntime); - if (needs_wrapper) { - sub_fn_target = sub_fn_target.without_feature(Target::Matlab); - } { ScopedCompilerLogger activate(compiler_logger_factory, sub_fn_name, sub_fn_target); @@ -1019,7 +1040,7 @@ void compile_multitarget(const std::string &fn_name, // and add that to the result. if (!base_target.has_feature(Target::NoRuntime)) { // Start with a bare Target, set only the features we know are common to all. - Target runtime_target(base_target.os, base_target.arch, base_target.bits); + Target runtime_target(base_target.os, base_target.arch, base_target.bits, base_target.processor_tune); for (int i = 0; i < Target::FeatureEnd; ++i) { // We never want NoRuntime set here. if (i == Target::NoRuntime) { @@ -1059,12 +1080,6 @@ void compile_multitarget(const std::string &fn_name, .with_feature(Target::NoBoundsQuery) .without_feature(Target::NoAsserts); - // If the base target specified the Matlab target, we want the Matlab target - // on the wrapper instead. - if (base_target.has_feature(Target::Matlab)) { - wrapper_target = wrapper_target.with_feature(Target::Matlab); - } - Module wrapper_module(fn_name, wrapper_target); wrapper_module.append(LoweredFunc(fn_name, base_target_args, wrapper_body, LinkageType::ExternalPlusMetadata)); @@ -1096,6 +1111,7 @@ void compile_multitarget(const std::string &fn_name, if (contains(output_files, OutputFileType::schedule)) { debug(1) << "compile_multitarget: schedule " << output_files.at(OutputFileType::schedule) << "\n"; +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API std::string scheduler = auto_scheduler_results.front().scheduler_name; if (scheduler.empty()) { scheduler = "(None)"; @@ -1104,6 +1120,11 @@ void compile_multitarget(const std::string &fn_name, if (machine_params.empty()) { machine_params = "(None)"; } +#else + const auto &autoscheduler_params = auto_scheduler_results.front().autoscheduler_params; + std::string scheduler = autoscheduler_params.name.empty() ? "(None)" : autoscheduler_params.name; + std::string autoscheduler_params_string = autoscheduler_params.name.empty() ? "(None)" : autoscheduler_params.to_string(); +#endif // Find the features that are unique to each stage (vs the baseline case). const auto &baseline_target = auto_scheduler_results.back().target; @@ -1145,15 +1166,24 @@ void compile_multitarget(const std::string &fn_name, } std::ofstream file(output_files.at(OutputFileType::schedule)); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API emit_schedule_file(fn_name, targets, scheduler, machine_params, body.str(), file); +#else + emit_schedule_file(fn_name, targets, scheduler, autoscheduler_params_string, body.str(), file); +#endif } if (contains(output_files, OutputFileType::python_schedule)) { debug(1) << "compile_multitarget: python_schedule " << output_files.at(OutputFileType::python_schedule) << "\n"; +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API std::string scheduler = auto_scheduler_results.front().scheduler_name; + std::string machine_params = auto_scheduler_results.front().machine_params_string; +#else + std::string scheduler = auto_scheduler_results.front().autoscheduler_params.name; + std::string machine_params = ""; +#endif if (scheduler.empty()) { scheduler = "(None)"; } - std::string machine_params = auto_scheduler_results.front().machine_params_string; if (machine_params.empty()) { machine_params = "(None)"; } diff --git a/src/Module.h b/src/Module.h index 1bc5c24f172b..95522f5030db 100644 --- a/src/Module.h +++ b/src/Module.h @@ -13,7 +13,9 @@ #include "Argument.h" #include "Expr.h" +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE #include "ExternalCode.h" +#endif #include "Function.h" // for NameMangling #include "ModulusRemainder.h" @@ -45,42 +47,6 @@ enum class OutputFileType { stmt_html, }; -class HALIDE_ATTRIBUTE_DEPRECATED("Use OutputFileType instead of Output") Output { -public: - HALIDE_ATTRIBUTE_DEPRECATED("Use OutputFileType instead of Output") - static constexpr OutputFileType assembly = OutputFileType::assembly; - HALIDE_ATTRIBUTE_DEPRECATED("Use OutputFileType instead of Output") - static constexpr OutputFileType bitcode = OutputFileType::bitcode; - HALIDE_ATTRIBUTE_DEPRECATED("Use OutputFileType instead of Output") - static constexpr OutputFileType c_header = OutputFileType::c_header; - HALIDE_ATTRIBUTE_DEPRECATED("Use OutputFileType instead of Output") - static constexpr OutputFileType c_source = OutputFileType::c_source; - HALIDE_ATTRIBUTE_DEPRECATED("Use OutputFileType instead of Output") - static constexpr OutputFileType compiler_log = OutputFileType::compiler_log; - HALIDE_ATTRIBUTE_DEPRECATED("Use OutputFileType instead of Output") - static constexpr OutputFileType cpp_stub = OutputFileType::cpp_stub; - HALIDE_ATTRIBUTE_DEPRECATED("Use OutputFileType instead of Output") - static constexpr OutputFileType featurization = OutputFileType::featurization; - HALIDE_ATTRIBUTE_DEPRECATED("Use OutputFileType instead of Output") - static constexpr OutputFileType llvm_assembly = OutputFileType::llvm_assembly; - HALIDE_ATTRIBUTE_DEPRECATED("Use OutputFileType instead of Output") - static constexpr OutputFileType object = OutputFileType::object; - HALIDE_ATTRIBUTE_DEPRECATED("Use OutputFileType instead of Output") - static constexpr OutputFileType python_extension = OutputFileType::python_extension; - HALIDE_ATTRIBUTE_DEPRECATED("Use OutputFileType instead of Output") - static constexpr OutputFileType pytorch_wrapper = OutputFileType::pytorch_wrapper; - HALIDE_ATTRIBUTE_DEPRECATED("Use OutputFileType instead of Output") - static constexpr OutputFileType registration = OutputFileType::registration; - HALIDE_ATTRIBUTE_DEPRECATED("Use OutputFileType instead of Output") - static constexpr OutputFileType schedule = OutputFileType::schedule; - HALIDE_ATTRIBUTE_DEPRECATED("Use OutputFileType instead of Output") - static constexpr OutputFileType static_library = OutputFileType::static_library; - HALIDE_ATTRIBUTE_DEPRECATED("Use OutputFileType instead of Output") - static constexpr OutputFileType stmt = OutputFileType::stmt; - HALIDE_ATTRIBUTE_DEPRECATED("Use OutputFileType instead of Output") - static constexpr OutputFileType stmt_html = OutputFileType::stmt_html; -}; // namespace Output - /** Type of linkage a function in a lowered Halide module can have. Also controls whether auxiliary functions and metadata are generated. */ enum class LinkageType { @@ -169,6 +135,8 @@ class CompilerLogger; struct AutoSchedulerResults; +using MetadataNameMap = std::map; + /** A halide module. This represents IR containing lowered function * definitions and buffers. */ class Module { @@ -197,7 +165,9 @@ class Module { const std::vector &functions() const; std::vector &functions(); const std::vector &submodules() const; +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE const std::vector &external_code() const; +#endif // @} /** Return the function with the given name. If no such function @@ -209,7 +179,9 @@ class Module { void append(const Buffer &buffer); void append(const Internal::LoweredFunc &function); void append(const Module &module); +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE void append(const ExternalCode &external_code); +#endif // @} /** Compile a halide Module to variety of outputs, depending on @@ -230,7 +202,7 @@ class Module { void remap_metadata_name(const std::string &from, const std::string &to) const; /** Retrieve the metadata name map. */ - std::map get_metadata_name_map() const; + MetadataNameMap get_metadata_name_map() const; /** Set the AutoSchedulerResults for the Module. It is an error to call this * multiple times for a given Module. */ diff --git a/src/ModulusRemainder.cpp b/src/ModulusRemainder.cpp index 1e7d49aa3e04..34a598e4c7e3 100644 --- a/src/ModulusRemainder.cpp +++ b/src/ModulusRemainder.cpp @@ -35,6 +35,7 @@ class ComputeModulusRemainder : public IRVisitor { void visit(const FloatImm *) override; void visit(const StringImm *) override; void visit(const Cast *) override; + void visit(const Reinterpret *) override; void visit(const Variable *) override; void visit(const Add *) override; void visit(const Sub *) override; @@ -103,6 +104,10 @@ void ComputeModulusRemainder::visit(const Cast *) { result = ModulusRemainder{}; } +void ComputeModulusRemainder::visit(const Reinterpret *) { + result = ModulusRemainder{}; +} + void ComputeModulusRemainder::visit(const Variable *op) { if (scope.contains(op->name)) { result = scope.get(op->name); diff --git a/src/Monotonic.cpp b/src/Monotonic.cpp index 62910355f5ff..cec309571aa8 100644 --- a/src/Monotonic.cpp +++ b/src/Monotonic.cpp @@ -259,6 +259,10 @@ class DerivativeBounds : public IRVisitor { } } + void visit(const Reinterpret *op) override { + result = ConstantInterval::everything(); + } + void visit(const Variable *op) override { if (op->name == var) { result = ConstantInterval::single_point(1); @@ -476,7 +480,8 @@ class DerivativeBounds : public IRVisitor { } if (op->is_intrinsic(Call::unsafe_promise_clamped) || - op->is_intrinsic(Call::promise_clamped)) { + op->is_intrinsic(Call::promise_clamped) || + op->is_intrinsic(Call::saturating_cast)) { op->args[0].accept(this); return; } diff --git a/src/ParallelRVar.cpp b/src/ParallelRVar.cpp index 8538583cf791..c210e487f3ad 100644 --- a/src/ParallelRVar.cpp +++ b/src/ParallelRVar.cpp @@ -102,6 +102,12 @@ bool can_parallelize_rvar(const string &v, value.accept(&find); } + // add loads from predicate + const Expr pred = simplify(r.predicate()); + if (pred.defined()) { + pred.accept(&find); + } + // Make an expr representing the store done by a different thread. RenameFreeVars renamer; auto other_store = renamer.mutate(args); @@ -139,7 +145,6 @@ bool can_parallelize_rvar(const string &v, } // Add the definition's predicate if there is any - Expr pred = simplify(r.predicate()); if (pred.defined() || !equal(const_true(), pred)) { Expr this_pred = pred; Expr other_pred = renamer.mutate(pred); diff --git a/src/Param.h b/src/Param.h index 9295c45f2a04..03fb58bef6aa 100644 --- a/src/Param.h +++ b/src/Param.h @@ -105,7 +105,7 @@ class Param { } /** Construct a scalar parameter of type T with an initial value of 'val' - * and a given min and max. */ + * and a given min and max. */ Param(not_void_T val, const Expr &min, const Expr &max) : param(type_of(), false, 0, Internal::make_entity_name(this, "Halide:.*:Param<.*>", 'p')) { static_assert(has_static_type, "Cannot use this ctor without an explicit type."); @@ -170,47 +170,49 @@ class Param { /** Set the current value of this parameter. Only meaningful when jitting. Asserts if type is not losslessly-convertible to Parameter's type. */ - // @{ - template::value>::type * = nullptr> - HALIDE_NO_USER_CODE_INLINE void set(const SOME_TYPE &val) { - user_assert(Internal::IsRoundtrippable::value(val)) - << "The value " << val << " cannot be losslessly converted to type " << type(); - param.set_scalar(val); - } - - // Specialized version for when T = void (thus the type is only known at runtime, - // not compiletime). Note that this actually works fine for all Params; we specialize - // it just to reduce code size for the common case of T != void. - template::value>::type * = nullptr> + template HALIDE_NO_USER_CODE_INLINE void set(const SOME_TYPE &val) { -#define HALIDE_HANDLE_TYPE_DISPATCH(CODE, BITS, TYPE) \ - case halide_type_t(CODE, BITS).as_u32(): \ - user_assert(Internal::IsRoundtrippable::value(val)) \ - << "The value " << val << " cannot be losslessly converted to type " << type; \ - param.set_scalar(Internal::StaticCast::value(val)); \ - break; - - const Type type = param.type(); - switch (((halide_type_t)type).element_of().as_u32()) { - HALIDE_HANDLE_TYPE_DISPATCH(halide_type_float, 32, float) - HALIDE_HANDLE_TYPE_DISPATCH(halide_type_float, 64, double) - HALIDE_HANDLE_TYPE_DISPATCH(halide_type_int, 8, int8_t) - HALIDE_HANDLE_TYPE_DISPATCH(halide_type_int, 16, int16_t) - HALIDE_HANDLE_TYPE_DISPATCH(halide_type_int, 32, int32_t) - HALIDE_HANDLE_TYPE_DISPATCH(halide_type_int, 64, int64_t) - HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 1, bool) - HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 8, uint8_t) - HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 16, uint16_t) - HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 32, uint32_t) - HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 64, uint64_t) - HALIDE_HANDLE_TYPE_DISPATCH(halide_type_handle, 64, uint64_t) // Handle types are always set via set_scalar, not set_scalar - default: - internal_error << "Unsupported type in Param::set<" << type << ">\n"; + if constexpr (!std::is_void::value) { + user_assert(Internal::IsRoundtrippable::value(val)) + << "The value " << val << " cannot be losslessly converted to type " << type(); + param.set_scalar(val); + } else { + // clang-format off + + // Specialized version for when T = void (thus the type is only known at runtime, + // not compiletime). Note that this actually works fine for all Params; we specialize + // it just to reduce code size for the common case of T != void. + + #define HALIDE_HANDLE_TYPE_DISPATCH(CODE, BITS, TYPE) \ + case halide_type_t(CODE, BITS).as_u32(): \ + user_assert(Internal::IsRoundtrippable::value(val)) \ + << "The value " << val << " cannot be losslessly converted to type " << type; \ + param.set_scalar(Internal::StaticCast::value(val)); \ + break; + + const Type type = param.type(); + switch (((halide_type_t)type).element_of().as_u32()) { + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_float, 32, float) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_float, 64, double) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_int, 8, int8_t) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_int, 16, int16_t) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_int, 32, int32_t) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_int, 64, int64_t) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 1, bool) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 8, uint8_t) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 16, uint16_t) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 32, uint32_t) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 64, uint64_t) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_handle, 64, uint64_t) // Handle types are always set via set_scalar, not set_scalar + default: + internal_error << "Unsupported type in Param::set<" << type << ">\n"; + } + + #undef HALIDE_HANDLE_TYPE_DISPATCH + + // clang-format on } - -#undef HALIDE_HANDLE_TYPE_DISPATCH } - // @} /** Get the halide type of the Param */ Type type() const { @@ -249,10 +251,47 @@ class Param { // @} template - void set_estimate(const SOME_TYPE &value) { - user_assert(Internal::IsRoundtrippable::value(value)) - << "The value " << value << " cannot be losslessly converted to type " << type(); - param.set_estimate(Expr(value)); + HALIDE_NO_USER_CODE_INLINE void set_estimate(const SOME_TYPE &val) { + if constexpr (!std::is_void::value) { + user_assert(Internal::IsRoundtrippable::value(val)) + << "The value " << val << " cannot be losslessly converted to type " << type(); + param.set_estimate(Expr(val)); + } else { + // clang-format off + + // Specialized version for when T = void (thus the type is only known at runtime, + // not compiletime). Note that this actually works fine for all Params; we specialize + // it just to reduce code size for the common case of T != void. + + #define HALIDE_HANDLE_TYPE_DISPATCH(CODE, BITS, TYPE) \ + case halide_type_t(CODE, BITS).as_u32(): \ + user_assert(Internal::IsRoundtrippable::value(val)) \ + << "The value " << val << " cannot be losslessly converted to type " << type; \ + param.set_estimate(Expr(Internal::StaticCast::value(val))); \ + break; + + const Type type = param.type(); + switch (((halide_type_t)type).element_of().as_u32()) { + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_float, 32, float) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_float, 64, double) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_int, 8, int8_t) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_int, 16, int16_t) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_int, 32, int32_t) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_int, 64, int64_t) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 1, bool) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 8, uint8_t) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 16, uint16_t) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 32, uint32_t) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 64, uint64_t) + HALIDE_HANDLE_TYPE_DISPATCH(halide_type_handle, 64, uint64_t) // Handle types are always set via set_scalar, not set_scalar + default: + internal_error << "Unsupported type in Param::set<" << type << ">\n"; + } + + #undef HALIDE_HANDLE_TYPE_DISPATCH + + // clang-format on + } } /** You can use this parameter as an expression in a halide diff --git a/src/PartitionLoops.cpp b/src/PartitionLoops.cpp index 68b273a550e6..678752f248f9 100644 --- a/src/PartitionLoops.cpp +++ b/src/PartitionLoops.cpp @@ -746,14 +746,14 @@ class PartitionLoops : public IRMutator { if (make_epilogue) { // Uncomment to include code that prints the epilogue value - //epilogue_val = print(epilogue_val, op->name, "epilogue"); + // epilogue_val = print(epilogue_val, op->name, "epilogue"); stmt = LetStmt::make(epilogue_name, epilogue_val, stmt); } else { epilogue_val = op->min + op->extent; } if (make_prologue) { // Uncomment to include code that prints the prologue value - //prologue_val = print(prologue_val, op->name, "prologue"); + // prologue_val = print(prologue_val, op->name, "prologue"); stmt = LetStmt::make(prologue_name, prologue_val, stmt); } else { prologue_val = op->min; diff --git a/src/Pipeline.cpp b/src/Pipeline.cpp index 5b28d06e678f..6ecd9e6d2cfd 100644 --- a/src/Pipeline.cpp +++ b/src/Pipeline.cpp @@ -2,8 +2,8 @@ #include #include - #include "Argument.h" +#include "Callable.h" #include "CodeGen_Internal.h" #include "FindCalls.h" #include "Func.h" @@ -60,31 +60,66 @@ std::map object_file_outputs(const string &filename return outputs; } +std::string sanitize_function_name(const std::string &s) { + string name = s; + for (char &c : name) { + if (!isalnum(c)) { + c = '_'; + } + } + return name; +} + } // namespace +namespace Internal { + +struct JITCallArgs { + size_t size{0}; + const void **store; + + JITCallArgs(size_t size) + : size(size) { + if (size > kStoreSize) { + store = new ConstVoidPtr[size]; + } else { + store = fixed_store; + } + } + + ~JITCallArgs() { + if (store != fixed_store) { + delete[] store; + } + } + +private: + static constexpr int kStoreSize = 64; + using ConstVoidPtr = const void *; + ConstVoidPtr fixed_store[kStoreSize]; + +public: + JITCallArgs(const JITCallArgs &other) = delete; + JITCallArgs &operator=(const JITCallArgs &other) = delete; + JITCallArgs(JITCallArgs &&other) = delete; + JITCallArgs &operator=(JITCallArgs &&other) = delete; +}; + +} // namespace Internal + struct PipelineContents { mutable RefCount ref_count; // Cached lowered stmt Module module; - // Name of the generated function - string name; - // Cached jit-compiled code - JITModule jit_module; - Target jit_target; - - // Cached compiled JavaScript and/or wasm if defined */ - WasmModule wasm_module; + JITCache jit_cache; /** Clear all cached state */ void invalidate_cache() { module = Module("", Target()); - jit_module = JITModule(); - jit_target = Target(); - inferred_args.clear(); - wasm_module = WasmModule(); + jit_cache = JITCache(); } // The outputs @@ -185,6 +220,7 @@ std::map &Pipeline::get_autoscheduler_map() { return autoschedulers; } +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API /* static */ std::string &Pipeline::get_default_autoscheduler_name() { static std::string autoscheduler_name = ""; @@ -193,6 +229,7 @@ std::string &Pipeline::get_default_autoscheduler_name() { } return autoscheduler_name; } +#endif /* static */ AutoSchedulerFn Pipeline::find_autoscheduler(const std::string &autoscheduler_name) { @@ -209,28 +246,41 @@ AutoSchedulerFn Pipeline::find_autoscheduler(const std::string &autoscheduler_na return it->second; } -std::string simplify_name(const std::string &s) { - // Trim the uniqueness $n suffixes on variables. - return s.substr(0, s.rfind('$')); -} - -AutoSchedulerResults Pipeline::auto_schedule(const std::string &autoscheduler_name, const Target &target, const MachineParams &arch_params) { +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API +AutoSchedulerResults Pipeline::auto_schedule(const std::string &autoscheduler_name, const Target &target, const MachineParams &arch_params) const { auto autoscheduler_fn = find_autoscheduler(autoscheduler_name); user_assert(autoscheduler_fn) << "Could not find autoscheduler named '" << autoscheduler_name << "'.\n" << "Did you remember to load the plugin?"; - autoscheduler_results.target = target; - autoscheduler_results.machine_params_string = arch_params.to_string(); - - autoscheduler_fn(*this, target, arch_params, &autoscheduler_results); + AutoSchedulerResults results; + results.target = target; + results.machine_params_string = arch_params.to_string(); - return autoscheduler_results; + autoscheduler_fn(*this, target, arch_params, &results); + return results; } -AutoSchedulerResults Pipeline::auto_schedule(const Target &target, const MachineParams &arch_params) { +AutoSchedulerResults Pipeline::auto_schedule(const Target &target, const MachineParams &arch_params) const { return auto_schedule(get_default_autoscheduler_name(), target, arch_params); } +#else +AutoSchedulerResults Pipeline::apply_autoscheduler(const Target &target, const AutoschedulerParams &autoscheduler_params) const { + user_assert(!autoscheduler_params.name.empty()) << "apply_autoscheduler was called with no Autoscheduler specified."; + + auto autoscheduler_fn = find_autoscheduler(autoscheduler_params.name); + user_assert(autoscheduler_fn) + << "Could not find autoscheduler named '" << autoscheduler_params.name << "'.\n" + << "Did you remember to load the plugin?"; + + AutoSchedulerResults results; + results.target = target; + results.autoscheduler_params = autoscheduler_params; + + autoscheduler_fn(*this, target, autoscheduler_params, &results); + return results; +} +#endif /* static */ void Pipeline::add_autoscheduler(const std::string &autoscheduler_name, const AutoSchedulerFn &autoscheduler) { @@ -239,11 +289,13 @@ void Pipeline::add_autoscheduler(const std::string &autoscheduler_name, const Au m[autoscheduler_name] = autoscheduler; } +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API /* static */ void Pipeline::set_default_autoscheduler_name(const std::string &autoscheduler_name) { (void)find_autoscheduler(autoscheduler_name); // ensure it's valid get_default_autoscheduler_name() = autoscheduler_name; } +#endif Func Pipeline::get_func(size_t index) { // Compute an environment @@ -533,39 +585,20 @@ Module Pipeline::compile_to_module(const vector &args, std::string Pipeline::generate_function_name() const { user_assert(defined()) << "Pipeline is undefined\n"; - // Come up with a name for a generated function - string name = contents->outputs[0].name(); - for (char &c : name) { - if (!isalnum(c)) { - c = '_'; - } - } - return name; + return sanitize_function_name(contents->outputs[0].name()); } // This essentially is just a getter for contents->jit_target, // but also reality-checks that the status of the jit_module and/or wasm_module // match what we expect. Target Pipeline::get_compiled_jit_target() const { - const bool has_wasm = contents->wasm_module.contents.defined(); - const bool has_native = contents->jit_module.compiled(); - if (contents->jit_target.arch == Target::WebAssembly) { - internal_assert(has_wasm && !has_native); - } else if (!contents->jit_target.has_unknowns()) { - internal_assert(!has_wasm && has_native); - } else { - internal_assert(!has_wasm && !has_native); - } - return contents->jit_target; + return contents->jit_cache.get_compiled_jit_target(); } void Pipeline::compile_jit(const Target &target_arg) { user_assert(defined()) << "Pipeline is undefined\n"; - user_assert(!target_arg.has_unknowns()) << "Cannot compile_jit() for target '" << target_arg << "'\n"; - Target target(target_arg); - target.set_feature(Target::JIT); - target.set_feature(Target::UserContext); + Target target = target_arg.with_feature(Target::JIT).with_feature(Target::UserContext); // If we're re-jitting for the same target, we can just keep the old jit module. if (get_compiled_jit_target() == target) { @@ -574,13 +607,9 @@ void Pipeline::compile_jit(const Target &target_arg) { return; } - debug(2) << "jit-compiling for: " << target_arg << "\n"; - // Clear all cached info in case there is an error. contents->invalidate_cache(); - contents->jit_target = target; - // Infer an arguments vector infer_arguments(); @@ -593,96 +622,82 @@ void Pipeline::compile_jit(const Target &target_arg) { args.push_back(arg.arg); } - // Come up with a name for the generated function - string name = generate_function_name(); - - // Compile to a module and also compile any submodules. - Module module = compile_to_module(args, name, target).resolve_submodules(); - + Module module = compile_to_module(args, generate_function_name(), target).resolve_submodules(); std::map lowered_externs = contents->jit_externs; + contents->jit_cache = compile_jit_cache(module, std::move(args), contents->outputs, contents->jit_externs, target); +} - if (target.arch == Target::WebAssembly) { - FindExterns find_externs(lowered_externs); - for (const LoweredFunc &f : contents->module.functions()) { - f.body.accept(&find_externs); - } - if (debug::debug_level() >= 1) { - for (const auto &p : lowered_externs) { - debug(1) << "Found extern: " << p.first << "\n"; - } - } +Callable Pipeline::compile_to_callable(const std::vector &args_in, const Target &target_arg) { + user_assert(defined()) << "Pipeline is undefined\n"; - vector args_and_outputs = args; - for (auto &out : contents->outputs) { - for (Type t : out.output_types()) { - args_and_outputs.emplace_back(out.name(), Argument::OutputBuffer, t, out.dimensions(), ArgumentEstimates{}); - } - } + Target target = target_arg.with_feature(Target::JIT).with_feature(Target::UserContext); - contents->wasm_module = WasmModule::compile( - module, - args_and_outputs, - contents->module.name(), - lowered_externs, - make_externs_jit_module(target, lowered_externs)); - return; - } + const Argument &user_context_arg = contents->user_context_arg.arg; - auto f = module.get_function_by_name(name); + std::vector args; + args.reserve(args_in.size() + contents->outputs.size() + 1); + // JITUserContext is always the first argument for Callables. + args.push_back(user_context_arg); + for (const Argument &a : args_in) { + user_assert(a.name != user_context_arg.name) << "You may not specify an explicit UserContext Argument to compile_to_callable()."; + args.push_back(a); + } - // Compile to jit module - JITModule jit_module(module, f, make_externs_jit_module(target_arg, lowered_externs)); + Module module = compile_to_module(args, generate_function_name(), target).resolve_submodules(); - // Dump bitcode to a file if the environment variable - // HL_GENBITCODE is defined to a nonzero value. - if (atoi(get_env_variable("HL_GENBITCODE").c_str())) { - string program_name = running_program_name(); - if (program_name.empty()) { - program_name = "unknown" + unique_name('_').substr(1); - } - string file_name = program_name + "_" + name + "_" + unique_name('g').substr(1) + ".bc"; - debug(4) << "Saving bitcode to: " << file_name << "\n"; - module.compile({{OutputFileType::bitcode, file_name}}); - } + auto jit_cache = compile_jit_cache(module, std::move(args), contents->outputs, get_jit_externs(), target); - contents->jit_module = jit_module; + // Save the jit_handlers and jit_externs as they were at the time this + // Callable was created, in case the Pipeline's version is mutated in + // between creation and call -- we want the Callable to remain immutable + // after creation, regardless of what you do to the Func. + return Callable(module.name(), jit_handlers(), get_jit_externs(), std::move(jit_cache)); } -template -void set_handler(A &a, B b) { - a = (A)b; -} +/*static*/ JITCache Pipeline::compile_jit_cache(const Module &module, + std::vector args, + const std::vector &outputs, + const std::map &jit_externs_in, + const Target &target_arg) { + user_assert(!target_arg.has_unknowns()) << "Cannot jit-compile for target '" << target_arg << "'\n"; -void Pipeline::set_error_handler(void (*handler)(void *, const char *)) { - user_assert(defined()) << "Pipeline is undefined\n"; - set_handler(contents->jit_handlers.custom_error, handler); -} + Target jit_target = target_arg.with_feature(Target::JIT).with_feature(Target::UserContext); -void Pipeline::set_custom_allocator(void *(*cust_malloc)(void *, size_t), - void (*cust_free)(void *, void *)) { - user_assert(defined()) << "Pipeline is undefined\n"; - set_handler(contents->jit_handlers.custom_malloc, cust_malloc); - set_handler(contents->jit_handlers.custom_free, cust_free); -} + debug(2) << "jit-compiling for: " << target_arg << "\n"; -void Pipeline::set_custom_do_par_for(int (*cust_do_par_for)(void *, int (*)(void *, int, uint8_t *), int, int, uint8_t *)) { - user_assert(defined()) << "Pipeline is undefined\n"; - set_handler(contents->jit_handlers.custom_do_par_for, cust_do_par_for); -} + for (const auto &out : outputs) { + for (Type t : out.output_types()) { + args.emplace_back(out.name(), Argument::OutputBuffer, t, out.dimensions(), ArgumentEstimates{}); + } + } -void Pipeline::set_custom_do_task(int (*cust_do_task)(void *, int (*)(void *, int, uint8_t *), int, uint8_t *)) { - user_assert(defined()) << "Pipeline is undefined\n"; - set_handler(contents->jit_handlers.custom_do_task, cust_do_task); -} + JITModule jit_module; + WasmModule wasm_module; -void Pipeline::set_custom_trace(int (*trace_fn)(void *, const halide_trace_event_t *)) { - user_assert(defined()) << "Pipeline is undefined\n"; - set_handler(contents->jit_handlers.custom_trace, trace_fn); -} + // Note that make_externs_jit_module() mutates the jit_externs, so we keep a copy + // TODO: it fills in the value side with JITExtern values, but does anything actually use those? + auto jit_externs = jit_externs_in; + std::vector externs_jit_module = Pipeline::make_externs_jit_module(jit_target, jit_externs); + if (jit_target.arch == Target::WebAssembly) { + FindExterns find_externs(jit_externs); + for (const LoweredFunc &f : module.functions()) { + f.body.accept(&find_externs); + } + if (debug::debug_level() >= 1) { + for (const auto &p : jit_externs) { + debug(1) << "Found extern: " << p.first << "\n"; + } + } -void Pipeline::set_custom_print(void (*cust_print)(void *, const char *)) { - user_assert(defined()) << "Pipeline is undefined\n"; - set_handler(contents->jit_handlers.custom_print, cust_print); + wasm_module = WasmModule::compile(module, args, + module.name(), jit_externs, externs_jit_module); + } else { + std::string name = sanitize_function_name(outputs[0].name()); + auto f = module.get_function_by_name(name); + jit_module = JITModule(module, f, externs_jit_module); + } + + return JITCache(jit_target, std::move(args), std::move(jit_externs), std::move(jit_module), std::move(wasm_module)); } void Pipeline::set_jit_externs(const std::map &externs) { @@ -732,12 +747,14 @@ Realization Pipeline::realize(JITUserContext *context, user_assert(defined()) << "Pipeline is undefined\n"; vector> bufs; for (auto &out : contents->outputs) { + user_assert((int)sizes.size() == out.dimensions()) + << "Func " << out.name() << " is defined with " << out.dimensions() << " dimensions, but realize() is requesting a realization with " << sizes.size() << " dimensions.\n"; user_assert(out.has_pure_definition() || out.has_extern_definition()) << "Can't realize Pipeline with undefined output Func: " << out.name() << ".\n"; for (Type t : out.output_types()) { bufs.emplace_back(t, nullptr, sizes); } } - Realization r(bufs); + Realization r(std::move(bufs)); // Do an output bounds query if we can. Otherwise just assume the // output size is good. if (!target.has_feature(Target::NoBoundsQuery)) { @@ -808,127 +825,6 @@ void Pipeline::trace_pipeline() { contents->trace_pipeline = true; } -namespace Internal { -struct JITErrorBuffer { - enum { MaxBufSize = 4096 }; - char buf[MaxBufSize]; - std::atomic end{0}; - - void concat(const char *message) { - size_t len = strlen(message); - - if (len && message[len - 1] != '\n') { - // Claim some extra space for a newline. - len++; - } - - // Atomically claim some space in the buffer - size_t old_end = end.fetch_add(len); - - if (old_end + len >= MaxBufSize - 1) { - // Out of space - return; - } - - for (size_t i = 0; i < len - 1; i++) { - buf[old_end + i] = message[i]; - } - if (buf[old_end + len - 2] != '\n') { - buf[old_end + len - 1] = '\n'; - } - } - - std::string str() const { - return std::string(buf, end); - } - - static void handler(JITUserContext *ctx, const char *message) { - if (ctx && ctx->error_buffer) { - ctx->error_buffer->concat(message); - } - } -}; - -struct JITFuncCallContext { - JITErrorBuffer error_buffer; - JITUserContext *context; - bool custom_error_handler; - - JITFuncCallContext(JITUserContext *context, const JITHandlers &pipeline_handlers) - : context(context) { - custom_error_handler = (context->handlers.custom_error != nullptr || - pipeline_handlers.custom_error != nullptr); - // Hook the error handler if not set - if (!custom_error_handler) { - context->handlers.custom_error = JITErrorBuffer::handler; - } - - // Add the handlers stored in the pipeline for anything else - // not set, then for anything still not set, use the global - // active handlers. - JITSharedRuntime::populate_jit_handlers(context, pipeline_handlers); - context->error_buffer = &error_buffer; - - debug(2) << "custom_print: " << (void *)context->handlers.custom_print << "\n" - << "custom_malloc: " << (void *)context->handlers.custom_malloc << "\n" - << "custom_free: " << (void *)context->handlers.custom_free << "\n" - << "custom_do_task: " << (void *)context->handlers.custom_do_task << "\n" - << "custom_do_par_for: " << (void *)context->handlers.custom_do_par_for << "\n" - << "custom_error: " << (void *)context->handlers.custom_error << "\n" - << "custom_trace: " << (void *)context->handlers.custom_trace << "\n"; - } - - void report_if_error(int exit_status) { - // Only report the errors if no custom error handler was installed - if (exit_status && !custom_error_handler) { - std::string output = error_buffer.str(); - if (output.empty()) { - output = ("The pipeline returned exit status " + - std::to_string(exit_status) + - " but halide_error was never called.\n"); - } - halide_runtime_error << output; - error_buffer.end = 0; - } - } - - void finalize(int exit_status) { - report_if_error(exit_status); - } -}; -} // namespace Internal - -struct Pipeline::JITCallArgs { - size_t size{0}; - const void **store; - - JITCallArgs(size_t size) - : size(size) { - if (size > kStoreSize) { - store = new ConstVoidPtr[size]; - } else { - store = fixed_store; - } - } - - ~JITCallArgs() { - if (store != fixed_store) { - delete[] store; - } - } - -private: - static constexpr int kStoreSize = 64; - using ConstVoidPtr = const void *; - ConstVoidPtr fixed_store[kStoreSize]; - -public: - JITCallArgs(const JITCallArgs &other) = delete; - JITCallArgs &operator=(const JITCallArgs &other) = delete; - JITCallArgs(JITCallArgs &&other) = delete; - JITCallArgs &operator=(JITCallArgs &&other) = delete; -}; - // Make a vector of void *'s to pass to the jit call using the // currently bound value for all of the params and image // params. @@ -945,9 +841,9 @@ void Pipeline::prepare_jit_call_arguments(RealizationArg &outputs, const Target << "Realization requires " << outputs.size() << " output(s) but pipeline produces " << total_outputs << " result(s).\n"; - JITModule &compiled_module = contents->jit_module; + JITModule &compiled_module = contents->jit_cache.jit_module; internal_assert(compiled_module.argv_function() || - contents->wasm_module.contents.defined()); + contents->jit_cache.wasm_module.contents.defined()); const bool no_param_map = ¶m_map == &ParamMap::empty_map(); @@ -1006,7 +902,7 @@ void Pipeline::prepare_jit_call_arguments(RealizationArg &outputs, const Target } } -std::vector +/*static*/ std::vector Pipeline::make_externs_jit_module(const Target &target, std::map &externs_in_out) { std::vector result; @@ -1022,9 +918,9 @@ Pipeline::make_externs_jit_module(const Target &target, // Ensure that the pipeline is compiled. pipeline.compile_jit(target); - free_standing_jit_externs.add_dependency(pipeline_contents.jit_module); - free_standing_jit_externs.add_symbol_for_export(iter.first, pipeline_contents.jit_module.entrypoint_symbol()); - void *address = pipeline_contents.jit_module.entrypoint_symbol().address; + free_standing_jit_externs.add_dependency(pipeline_contents.jit_cache.jit_module); + free_standing_jit_externs.add_symbol_for_export(iter.first, pipeline_contents.jit_cache.jit_module.entrypoint_symbol()); + void *address = pipeline_contents.jit_cache.jit_module.entrypoint_symbol().address; std::vector arg_types; // Add the arguments to the compiled pipeline for (const InferredArgument &arg : pipeline_contents.inferred_args) { @@ -1052,19 +948,7 @@ Pipeline::make_externs_jit_module(const Target &target, } int Pipeline::call_jit_code(const Target &target, const JITCallArgs &args) { -#if defined(__has_feature) -#if __has_feature(memory_sanitizer) - user_warning << "MSAN does not support JIT compilers of any sort, and will report " - "false positives when used in conjunction with the Halide JIT. " - "If you need to test with MSAN enabled, you must use ahead-of-time " - "compilation for Halide code."; -#endif -#endif - if (target.arch == Target::WebAssembly) { - internal_assert(contents->wasm_module.contents.defined()); - return contents->wasm_module.run(args.store); - } - return contents->jit_module.argv_function()(args.store); + return contents->jit_cache.call_jit_code(target, args.store); } void Pipeline::realize(RealizationArg outputs, @@ -1169,19 +1053,7 @@ void Pipeline::realize(JITUserContext *context, debug(2) << "Back from jitted function. Exit status was " << exit_status << "\n"; // If we're profiling, report runtimes and reset profiler stats. - if (target.has_feature(Target::Profile) || target.has_feature(Target::ProfileByTimer)) { - JITModule::Symbol report_sym = - contents->jit_module.find_symbol_by_name("halide_profiler_report"); - JITModule::Symbol reset_sym = - contents->jit_module.find_symbol_by_name("halide_profiler_reset"); - if (report_sym.address && reset_sym.address) { - void (*report_fn_ptr)(JITUserContext *) = (void (*)(JITUserContext *))(report_sym.address); - report_fn_ptr(context); - - void (*reset_fn_ptr)() = (void (*)())(reset_sym.address); - reset_fn_ptr(); - } - } + contents->jit_cache.finish_profiling(context); jit_call_context.finalize(exit_status); } @@ -1206,7 +1078,7 @@ void Pipeline::infer_input_bounds(JITUserContext *context, size_t args_size = contents->inferred_args.size() + outputs.size(); JITCallArgs args(args_size); - prepare_jit_call_arguments(outputs, contents->jit_target, param_map, + prepare_jit_call_arguments(outputs, contents->jit_cache.jit_target, param_map, &context, true, args); struct TrackedBuffer { @@ -1248,8 +1120,8 @@ void Pipeline::infer_input_bounds(JITUserContext *context, } Internal::debug(2) << "Calling jitted function\n"; - int exit_status = call_jit_code(contents->jit_target, args); - jit_context.report_if_error(exit_status); + int exit_status = call_jit_code(contents->jit_cache.jit_target, args); + jit_context.finalize(exit_status); Internal::debug(2) << "Back from jitted function\n"; bool changed = false; @@ -1316,7 +1188,7 @@ void Pipeline::infer_input_bounds(JITUserContext *context, for (Type t : contents->outputs[0].output_types()) { bufs.emplace_back(t, sizes); } - Realization r(bufs); + Realization r(std::move(bufs)); infer_input_bounds(context, r, target, param_map); } @@ -1338,6 +1210,7 @@ JITExtern::JITExtern(const ExternCFunction &extern_c_function) : extern_c_function_(extern_c_function) { } +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API MachineParams MachineParams::generic() { std::string params = Internal::get_env_variable("HL_MACHINE_PARAMS"); if (params.empty()) { @@ -1360,5 +1233,17 @@ MachineParams::MachineParams(const std::string &s) { last_level_cache_size = std::atoll(v[1].c_str()); balance = std::atof(v[2].c_str()); } +#else +std::string AutoschedulerParams::to_string() const { + std::ostringstream os; + if (!name.empty()) { + os << "autoscheduler=" << name; + } + for (const auto &kv : extra) { + os << " autoscheduler." << kv.first << "=" << kv.second; + } + return os.str(); +} +#endif } // namespace Halide diff --git a/src/Pipeline.h b/src/Pipeline.h index 6950f86ed837..4fb223a487a7 100644 --- a/src/Pipeline.h +++ b/src/Pipeline.h @@ -12,7 +12,9 @@ #include #include +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE #include "ExternalCode.h" +#endif #include "IROperator.h" #include "IntrusivePtr.h" #include "JITModule.h" @@ -25,9 +27,11 @@ namespace Halide { struct Argument; +class Callable; class Func; struct PipelineContents; +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API /** A struct representing the machine parameters to generate the auto-scheduled * code for. */ struct MachineParams { @@ -52,9 +56,45 @@ struct MachineParams { /** Reconstruct a MachineParams from canonical string form. */ explicit MachineParams(const std::string &s); }; +#else +/** Special the Autoscheduler to be used (if any), along with arbitrary + * additional arguments specific to the given Autoscheduler. + * + * The 'name' field specifies the type of Autoscheduler + * to be used (e.g. Adams2019, Mullapudi2016). If this is an empty string, + * no autoscheduling will be done; if not, it mustbe the name of a known Autoscheduler. + * + * At this time, well-known autoschedulers include: + * "Mullapudi2016" -- heuristics-based; the first working autoscheduler; currently built in to libHalide + * see http://graphics.cs.cmu.edu/projects/halidesched/ + * "Adams2019" -- aka "the ML autoscheduler"; currently located in apps/autoscheduler + * see https://halide-lang.org/papers/autoscheduler2019.html + * "Li2018" -- aka "the gradient autoscheduler"; currently located in apps/gradient_autoscheduler. + * see https://people.csail.mit.edu/tzumao/gradient_halide + * + * The key/value pairs in 'extra' are defined on a per-autoscheduler basis. + * An autoscheduler can have any number of required or optional keys. + */ +struct AutoschedulerParams { + std::string name; + std::map extra; + + AutoschedulerParams() = default; + /*not-explicit*/ AutoschedulerParams(const std::string &name) + : name(name) { + } + AutoschedulerParams(const std::string &name, const std::map &extra) + : name(name), extra(extra) { + } + + std::string to_string() const; +}; +#endif namespace Internal { class IRMutator; +struct JITCache; +struct JITCallArgs; } // namespace Internal /** @@ -83,10 +123,14 @@ struct CustomLoweringPass { struct JITExtern; struct AutoSchedulerResults { - AutoSchedulerResults(): scheduler_name(), target(), machine_params_string(), schedule_source(), python_schedule_source(), featurization(), path_featurization() {} - std::string scheduler_name; // name of the autoscheduler used - Target target; // Target specified to the autoscheduler - std::string machine_params_string; // MachineParams specified to the autoscheduler (in string form) +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + std::string scheduler_name; // name of the autoscheduler used + Target target; // Target specified to the autoscheduler + std::string machine_params_string; // MachineParams specified to the autoscheduler (in string form) +#else + Target target; // Target specified to the autoscheduler + AutoschedulerParams autoscheduler_params; // The autoscheduler used, along with its params +#endif std::string schedule_source; // The C++ source code of the generated schedule std::string python_schedule_source; // The Python source code of the generated schedule std::vector featurization; // The featurization of the pipeline (if any) @@ -96,7 +140,11 @@ struct AutoSchedulerResults { class Pipeline; +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API using AutoSchedulerFn = std::function; +#else +using AutoSchedulerFn = std::function; +#endif /** A class representing a Halide pipeline. Constructed from the Func * or Funcs that it outputs. */ @@ -145,22 +193,22 @@ class Pipeline { private: Internal::IntrusivePtr contents; - struct JITCallArgs; // Opaque structure to optimize away dynamic allocation in this path. - // For the three method below, precisely one of the first two args should be non-null void prepare_jit_call_arguments(RealizationArg &output, const Target &target, const ParamMap ¶m_map, - JITUserContext **user_context, bool is_bounds_inference, JITCallArgs &args_result); + JITUserContext **user_context, bool is_bounds_inference, Internal::JITCallArgs &args_result); static std::vector make_externs_jit_module(const Target &target, std::map &externs_in_out); static std::map &get_autoscheduler_map(); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API static std::string &get_default_autoscheduler_name(); +#endif static AutoSchedulerFn find_autoscheduler(const std::string &autoscheduler_name); - int call_jit_code(const Target &target, const JITCallArgs &args); + int call_jit_code(const Target &target, const Internal::JITCallArgs &args); // Get the value of contents->jit_target, but reality-check that the contents // sensibly match the value. Return Target() if not jitted. @@ -168,6 +216,12 @@ class Pipeline { AutoSchedulerResults autoscheduler_results; // saving them for future use. + static Internal::JITCache compile_jit_cache(const Module &module, + std::vector args, + const std::vector &outputs, + const std::map &jit_externs, + const Target &target_arg); + public: /** Make an undefined Pipeline object. */ Pipeline(); @@ -185,19 +239,26 @@ class Pipeline { /** Get the Funcs this pipeline outputs. */ std::vector outputs() const; +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API /** Generate a schedule for the pipeline using the currently-default autoscheduler. */ AutoSchedulerResults auto_schedule(const Target &target, - const MachineParams &arch_params = MachineParams::generic()); + const MachineParams &arch_params = MachineParams::generic()) const; /** Generate a schedule for the pipeline using the specified autoscheduler. */ AutoSchedulerResults auto_schedule(const std::string &autoscheduler_name, const Target &target, - const MachineParams &arch_params = MachineParams::generic()); - + const MachineParams &arch_params = MachineParams::generic()) const; +#else + /** Generate a schedule for the pipeline using the specified autoscheduler. */ + AutoSchedulerResults apply_autoscheduler(const Target &target, + const AutoschedulerParams &autoscheduler_params) const; +#endif + /** Add a new the autoscheduler method with the given name. Does not affect the current default autoscheduler. * It is an error to call this with the same name multiple times. */ static void add_autoscheduler(const std::string &autoscheduler_name, const AutoSchedulerFn &autoscheduler); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API /** Globally set the default autoscheduler method to use whenever * autoscheduling any Pipeline when no name is specified. If the autoscheduler_name isn't in the * current table of known autoschedulers, assert-fail. @@ -211,6 +272,7 @@ class Pipeline { * see https://people.csail.mit.edu/tzumao/gradient_halide */ static void set_default_autoscheduler_name(const std::string &autoscheduler_name); +#endif /** Return handle to the index-th Func within the pipeline based on the * topological order. */ @@ -355,27 +417,13 @@ class Pipeline { */ void compile_jit(const Target &target = get_jit_target_from_environment()); - /** Deprecated variants of the above that use a void pointer - * instead of a JITUserContext pointer. */ - // @{ - HALIDE_ATTRIBUTE_DEPRECATED("Custom handlers should by set by modifying the struct returned by jit_handlers()") - void set_error_handler(void (*handler)(void *, const char *)); - HALIDE_ATTRIBUTE_DEPRECATED("Custom handlers should by set by modifying the struct returned by jit_handlers()") - void set_custom_allocator(void *(*malloc)(void *, size_t), - void (*free)(void *, void *)); - HALIDE_ATTRIBUTE_DEPRECATED("Custom handlers should by set by modifying the struct returned by jit_handlers()") - void set_custom_do_task( - int (*custom_do_task)(void *, int (*)(void *, int, uint8_t *), - int, uint8_t *)); - HALIDE_ATTRIBUTE_DEPRECATED("Custom handlers should by set by modifying the struct returned by jit_handlers()") - void set_custom_do_par_for( - int (*custom_do_par_for)(void *, int (*)(void *, int, uint8_t *), int, - int, uint8_t *)); - HALIDE_ATTRIBUTE_DEPRECATED("Custom handlers should by set by modifying the struct returned by jit_handlers()") - void set_custom_trace(int (*trace_fn)(void *, const halide_trace_event_t *)); - HALIDE_ATTRIBUTE_DEPRECATED("Custom handlers should by set by modifying the struct returned by jit_handlers()") - void set_custom_print(void (*handler)(void *, const char *)); - // @} + /** Eagerly jit compile the function to machine code and return a callable + * struct that behaves like a function pointer. The calling convention + * will exactly match that of an AOT-compiled version of this Func + * with the same Argument list. + */ + Callable compile_to_callable(const std::vector &args, + const Target &target = get_jit_target_from_environment()); /** Install a set of external C functions or Funcs to satisfy * dependencies introduced by HalideExtern and define_extern diff --git a/src/Prefetch.h b/src/Prefetch.h index 6a29ea8bcd09..1b8287f881c1 100644 --- a/src/Prefetch.h +++ b/src/Prefetch.h @@ -21,14 +21,14 @@ struct PrefetchDirective; struct Stmt; /** Inject placeholder prefetches to 's'. This placholder prefetch - * does not have explicit region to be prefetched yet. It will be computed - * during call to \ref inject_prefetch. */ + * does not have explicit region to be prefetched yet. It will be computed + * during call to \ref inject_prefetch. */ Stmt inject_placeholder_prefetch(const Stmt &s, const std::map &env, const std::string &prefix, const std::vector &prefetches); /** Compute the actual region to be prefetched and place it to the - * placholder prefetch. Wrap the prefetch call with condition when - * applicable. */ + * placholder prefetch. Wrap the prefetch call with condition when + * applicable. */ Stmt inject_prefetch(const Stmt &s, const std::map &env); /** Reduce a multi-dimensional prefetch into a prefetch of lower dimension diff --git a/src/PythonExtensionGen.cpp b/src/PythonExtensionGen.cpp index bcfa189acf63..0ada655720b3 100644 --- a/src/PythonExtensionGen.cpp +++ b/src/PythonExtensionGen.cpp @@ -42,7 +42,7 @@ bool can_convert(const LoweredArgument *arg) { if (arg->type.is_handle()) { if (arg->name == "__user_context") { /* __user_context is a void* pointer to a user supplied memory region. - * We allow the Python callee to pass PyObject* pointers to that. */ + * We allow the Python callee to pass PyObject* pointers to that. */ return true; } else { return false; @@ -99,36 +99,10 @@ std::pair print_type(const LoweredArgument *arg) { } // namespace -void PythonExtensionGen::convert_buffer(const string &name, const LoweredArgument *arg) { - internal_assert(arg->is_buffer()); - internal_assert(arg->dimensions); - dest << " halide_buffer_t buffer_" << name << ";\n"; - dest << " halide_dimension_t dimensions_" << name << "[" << (int)arg->dimensions << "];\n"; - dest << " Py_buffer view_" << name << ";\n"; - dest << " if (_convert_py_buffer_to_halide("; - dest << /*pyobj*/ "py_" << name << ", "; - dest << /*dimensions*/ (int)arg->dimensions << ", "; - dest << /*flags*/ (arg->is_output() ? "PyBUF_WRITABLE" : "0") << ", "; - dest << /*dim*/ "dimensions_" << name << ", "; - dest << /*out*/ "&buffer_" << name << ", "; - dest << /*buf*/ "view_" << name << ", "; - dest << /*name*/ "\"" << name << "\""; - dest << ") < 0) {\n"; - release_buffers(" "); - dest << " return nullptr;\n"; - dest << " }\n"; -} - PythonExtensionGen::PythonExtensionGen(std::ostream &dest) : dest(dest) { } -void PythonExtensionGen::release_buffers(const string &prefix = " ") { - for (auto &buffer_ref : buffer_refs) { - dest << prefix << "PyBuffer_Release(&" << buffer_ref << ");\n"; - } -} - void PythonExtensionGen::compile(const Module &module) { dest << "#include \"Python.h\"\n"; dest << "#include \"HalideRuntime.h\"\n\n"; @@ -149,112 +123,126 @@ void PythonExtensionGen::compile(const Module &module) { dest << R"INLINE_CODE( /* Older Python versions don't set up PyMODINIT_FUNC correctly. */ #if defined(_MSC_VER) -# define HALIDE_PYTHON_EXPORT __declspec(dllexport) +#define HALIDE_PYTHON_EXPORT __declspec(dllexport) #else -# define HALIDE_PYTHON_EXPORT __attribute__((visibility("default"))) +#define HALIDE_PYTHON_EXPORT __attribute__((visibility("default"))) #endif -#ifdef __cplusplus -extern "C" { -#endif +namespace { -static -#if !defined(_MSC_VER) -__attribute__((unused)) -#endif -int _convert_py_buffer_to_halide( - PyObject* pyobj, int dimensions, int flags, - halide_dimension_t* dim, // array of size `dimensions` - halide_buffer_t* out, Py_buffer &buf, const char* name) { - int ret = PyObject_GetBuffer( - pyobj, &buf, PyBUF_FORMAT | PyBUF_STRIDED_RO | PyBUF_ANY_CONTIGUOUS | flags); - if (ret < 0) { - return ret; - } - if (dimensions && buf.ndim != dimensions) { - PyErr_Format(PyExc_ValueError, "Invalid argument %s: Expected %d dimensions, got %d", - name, dimensions, buf.ndim); - PyBuffer_Release(&buf); - return -1; - } - /* We'll get a buffer that's either: - * C_CONTIGUOUS (last dimension varies the fastest, i.e., has stride=1) or - * F_CONTIGUOUS (first dimension varies the fastest, i.e., has stride=1). - * The latter is preferred, since it's already in the format that Halide - * needs. It can can be achieved in numpy by passing order='F' during array - * creation. However, if we do get a C_CONTIGUOUS buffer, flip the dimensions - * (transpose) so we can process it without having to reallocate. - */ - int i, j, j_step; - if (PyBuffer_IsContiguous(&buf, 'F')) { - j = 0; - j_step = 1; - } else if (PyBuffer_IsContiguous(&buf, 'C')) { - j = buf.ndim - 1; - j_step = -1; - } else { - /* Python checks all dimensions and strides, so this typically indicates - * a bug in the array's buffer protocol. */ - PyErr_Format(PyExc_ValueError, "Invalid buffer: neither C nor Fortran contiguous"); - PyBuffer_Release(&buf); - return -1; - } - for (i = 0; i < buf.ndim; ++i, j += j_step) { - dim[i].min = 0; - dim[i].stride = (int)(buf.strides[j] / buf.itemsize); // strides is in bytes - dim[i].extent = (int)buf.shape[j]; - dim[i].flags = 0; - if (buf.suboffsets && buf.suboffsets[i] >= 0) { - // Halide doesn't support arrays of pointers. But we should never see this - // anyway, since we specified PyBUF_STRIDED. - PyErr_Format(PyExc_ValueError, "Invalid buffer: suboffsets not supported"); - PyBuffer_Release(&buf); - return -1; +template +struct PyHalideBuffer { + // Must allocate at least 1, even if d=0 + static constexpr int dims_to_allocate = (dimensions < 1) ? 1 : dimensions; + + Py_buffer py_buf; + halide_dimension_t halide_dim[dims_to_allocate]; + halide_buffer_t halide_buf; + bool py_buf_needs_release = false; + bool halide_buf_valid = false; + + PyHalideBuffer(PyObject *py_obj, int flags, const char *name) { + memset(&py_buf, 0, sizeof(py_buf)); + if (PyObject_GetBuffer(py_obj, &py_buf, PyBUF_FORMAT | PyBUF_STRIDED_RO | PyBUF_ANY_CONTIGUOUS | flags) < 0) { + PyErr_Format(PyExc_ValueError, "Invalid argument %s: Expected %d dimensions, got %d", name, dimensions, py_buf.ndim); + return; } - } - if (dim[buf.ndim - 1].extent * dim[buf.ndim - 1].stride * buf.itemsize != buf.len) { - PyErr_Format(PyExc_ValueError, "Invalid buffer: length %ld, but computed length %ld", - buf.len, buf.shape[0] * buf.strides[0]); - PyBuffer_Release(&buf); - return -1; - } - *out = halide_buffer_t(); - if (!buf.format) { - out->type.code = halide_type_uint; - out->type.bits = 8; - } else { - /* Convert struct type code. See - * https://docs.python.org/2/library/struct.html#module-struct */ - char* p = buf.format; - while (strchr("@<>!=", *p)) { - p++; // ignore little/bit endian (and alignment) + py_buf_needs_release = true; + + if (dimensions && py_buf.ndim != dimensions) { + PyErr_Format(PyExc_ValueError, "Invalid argument %s: Expected %d dimensions, got %d", name, dimensions, py_buf.ndim); + return; } - if (*p == 'f' || *p == 'd') { - // 'f' and 'd' are float and double, respectively. - out->type.code = halide_type_float; - } else if (*p >= 'a' && *p <= 'z') { - // lowercase is signed int. - out->type.code = halide_type_int; + /* We'll get a buffer that's either: + * C_CONTIGUOUS (last dimension varies the fastest, i.e., has stride=1) or + * F_CONTIGUOUS (first dimension varies the fastest, i.e., has stride=1). + * The latter is preferred, since it's already in the format that Halide + * needs. It can can be achieved in numpy by passing order='F' during array + * creation. However, if we do get a C_CONTIGUOUS buffer, flip the dimensions + * (transpose) so we can process it without having to reallocate. + */ + int i, j, j_step; + if (PyBuffer_IsContiguous(&py_buf, 'F')) { + j = 0; + j_step = 1; + } else if (PyBuffer_IsContiguous(&py_buf, 'C')) { + j = py_buf.ndim - 1; + j_step = -1; } else { - // uppercase is unsigned int. - out->type.code = halide_type_uint; + /* Python checks all dimensions and strides, so this typically indicates + * a bug in the array's buffer protocol. */ + PyErr_Format(PyExc_ValueError, "Invalid buffer: neither C nor Fortran contiguous"); + return; } - const char* type_codes = "bB?hHiIlLqQfd"; // integers and floats - if (strchr(type_codes, *p)) { - out->type.bits = (uint8_t)buf.itemsize * 8; + for (i = 0; i < py_buf.ndim; ++i, j += j_step) { + halide_dim[i].min = 0; + halide_dim[i].stride = (int)(py_buf.strides[j] / py_buf.itemsize); // strides is in bytes + halide_dim[i].extent = (int)py_buf.shape[j]; + halide_dim[i].flags = 0; + if (py_buf.suboffsets && py_buf.suboffsets[i] >= 0) { + // Halide doesn't support arrays of pointers. But we should never see this + // anyway, since we specified PyBUF_STRIDED. + PyErr_Format(PyExc_ValueError, "Invalid buffer: suboffsets not supported"); + return; + } + } + if (halide_dim[py_buf.ndim - 1].extent * halide_dim[py_buf.ndim - 1].stride * py_buf.itemsize != py_buf.len) { + PyErr_Format(PyExc_ValueError, "Invalid buffer: length %ld, but computed length %ld", + py_buf.len, py_buf.shape[0] * py_buf.strides[0]); + return; + } + + memset(&halide_buf, 0, sizeof(halide_buf)); + if (!py_buf.format) { + halide_buf.type.code = halide_type_uint; + halide_buf.type.bits = 8; } else { - // We don't handle 's' and 'p' (char[]) and 'P' (void*) - PyErr_Format(PyExc_ValueError, "Invalid data type for %s: %s", name, buf.format); - PyBuffer_Release(&buf); - return -1; + /* Convert struct type code. See + * https://docs.python.org/2/library/struct.html#module-struct */ + char *p = py_buf.format; + while (strchr("@<>!=", *p)) { + p++; // ignore little/bit endian (and alignment) + } + if (*p == 'f' || *p == 'd') { + // 'f' and 'd' are float and double, respectively. + halide_buf.type.code = halide_type_float; + } else if (*p >= 'a' && *p <= 'z') { + // lowercase is signed int. + halide_buf.type.code = halide_type_int; + } else { + // uppercase is unsigned int. + halide_buf.type.code = halide_type_uint; + } + const char *type_codes = "bB?hHiIlLqQfd"; // integers and floats + if (strchr(type_codes, *p)) { + halide_buf.type.bits = (uint8_t)py_buf.itemsize * 8; + } else { + // We don't handle 's' and 'p' (char[]) and 'P' (void*) + PyErr_Format(PyExc_ValueError, "Invalid data type for %s: %s", name, py_buf.format); + return; + } + } + halide_buf.type.lanes = 1; + halide_buf.dimensions = py_buf.ndim; + halide_buf.dim = halide_dim; + halide_buf.host = (uint8_t *)py_buf.buf; + halide_buf_valid = true; + } + + ~PyHalideBuffer() { + if (py_buf_needs_release) { + PyBuffer_Release(&py_buf); } } - out->type.lanes = 1; - out->dimensions = buf.ndim; - out->dim = dim; - out->host = (uint8_t*)buf.buf; - return 0; -} + + PyHalideBuffer() = delete; + PyHalideBuffer(const PyHalideBuffer &other) = delete; + PyHalideBuffer &operator=(const PyHalideBuffer &other) = delete; + PyHalideBuffer(PyHalideBuffer &&other) = delete; + PyHalideBuffer &operator=(PyHalideBuffer &&other) = delete; +}; + +} // namespace )INLINE_CODE"; @@ -265,7 +253,9 @@ int _convert_py_buffer_to_halide( } dest << "\n"; - dest << "static PyMethodDef _methods[] = {\n"; + dest << "namespace {\n"; + dest << "\n"; + dest << "PyMethodDef _methods[] = {\n"; for (const auto &f : module.functions()) { if (f.linkage == LinkageType::ExternalPlusMetadata) { const string basename = remove_namespaces(f.name); @@ -278,13 +268,19 @@ int _convert_py_buffer_to_halide( dest << R"INLINE_CODE( static_assert(PY_MAJOR_VERSION >= 3, "Python bindings for Halide require Python 3+"); -static struct PyModuleDef _moduledef = { + +struct PyModuleDef _moduledef = { PyModuleDef_HEAD_INIT, MODULE_NAME, nullptr, -1, _methods, }; + +} // namespace + +extern "C" { + HALIDE_PYTHON_EXPORT PyObject* PyInit_)INLINE_CODE"; dest << module.name() << "(void) {"; @@ -293,88 +289,137 @@ HALIDE_PYTHON_EXPORT PyObject* PyInit_)INLINE_CODE"; return PyModule_Create(&_moduledef); } -#ifdef __cplusplus -} -#endif +} // extern "C" + )INLINE_CODE"; } void PythonExtensionGen::compile(const LoweredFunc &f) { const std::vector &args = f.args; const string basename = remove_namespaces(f.name); + std::vector arg_names(args.size()); - dest << "// " << f.name << "\n"; - dest << "static PyObject* _f_" << basename << "(PyObject* module, PyObject* args, PyObject* kwargs) {\n"; for (size_t i = 0; i < args.size(); i++) { arg_names[i] = sanitize_name(args[i].name); - if (!can_convert(&args[i])) { + } + + Indentation indent; + indent.indent = 0; + + dest << "namespace {\n"; + dest << "\n"; + + dest << indent << "const char* const _f_" << basename << "_kwlist[] = {\n"; + indent.indent += 2; + for (size_t i = 0; i < args.size(); i++) { + dest << indent << "\"" << arg_names[i] << "\",\n"; + } + dest << indent << "nullptr\n"; + indent.indent -= 2; + dest << indent << "};\n\n"; + + dest << "// " << f.name << "\n"; + dest << "PyObject* _f_" << basename << "(PyObject* module, PyObject* args, PyObject* kwargs) {\n"; + + indent.indent += 2; + + for (const auto &arg : args) { + if (!can_convert(&arg)) { /* Some arguments can't be converted to Python yet. In those * cases, just add a dummy function that always throws an * Exception. */ // TODO: Add support for handles and vectors. - dest << " PyErr_Format(PyExc_NotImplementedError, " - << "\"Can't convert argument " << args[i].name << " from Python\");\n"; - dest << " return nullptr;\n"; - dest << "}"; + // TODO: might make more sense to simply fail at Halide compile time! + dest << indent << "PyErr_Format(PyExc_NotImplementedError, " + << "\"Can't convert argument " << arg.name << " from Python\");\n"; + dest << indent << "return nullptr;\n"; + dest << "}\n"; + dest << "} // namespace\n"; return; } } - dest << " static const char* const kwlist[] = {"; - for (size_t i = 0; i < args.size(); i++) { - dest << "\"" << arg_names[i] << "\", "; - } - dest << "nullptr};\n"; + for (size_t i = 0; i < args.size(); i++) { - dest << " " << print_type(&args[i]).second << " py_" << arg_names[i] << ";\n"; + dest << indent << print_type(&args[i]).second << " py_" << arg_names[i] << ";\n"; } - dest << " if (!PyArg_ParseTupleAndKeywords(args, kwargs, \""; + dest << indent << "if (!PyArg_ParseTupleAndKeywords(args, kwargs, \""; for (const auto &arg : args) { dest << print_type(&arg).first; } - dest << "\", (char**)kwlist"; + dest << "\", (char**)_f_" << basename << "_kwlist\n"; for (size_t i = 0; i < args.size(); i++) { - dest << ", "; - dest << "&py_" << arg_names[i]; + indent.indent += 2; + dest << indent << ", &py_" << arg_names[i] << "\n"; + indent.indent -= 2; } dest << ")) {\n"; - dest << " return nullptr;\n"; - dest << " }\n"; + indent.indent += 2; + dest << indent << "PyErr_Format(PyExc_ValueError, \"Internal error\");\n"; + dest << indent << "return nullptr;\n"; + indent.indent -= 2; + dest << indent << "}\n"; for (size_t i = 0; i < args.size(); i++) { if (args[i].is_buffer()) { - convert_buffer(arg_names[i], &args[i]); - buffer_refs.push_back("view_" + arg_names[i]); - } else { - // Python already converted this. - } + const auto &name = arg_names[i]; // must use sanitized names here + dest << indent << "PyHalideBuffer<" << (int)args[i].dimensions << "> b_" << name << "(" + << "py_" << name << ", " + << (args[i].is_output() ? "PyBUF_WRITABLE" : "0") << ", " + << "_f_" << basename << "_kwlist[" << i << "]);\n"; + dest << indent << "if (!b_" << name << ".halide_buf_valid) {\n"; + indent.indent += 2; + dest << indent << "return nullptr;\n"; + indent.indent -= 2; + dest << indent << "}\n"; + } // else Python already converted this. } - dest << " int result;\n"; - dest << " Py_BEGIN_ALLOW_THREADS\n"; - dest << " result = " << f.name << "("; + dest << "\n"; + // Mark all input buffers as having a dirty host, so that the Halide call will + // do a lazy-copy-to-GPU if needed. for (size_t i = 0; i < args.size(); i++) { - if (i > 0) { - dest << ", "; + if (args[i].is_buffer() && args[i].is_input()) { + dest << indent << "b_" << arg_names[i] << ".halide_buf.set_host_dirty();\n"; } + } + dest << indent << "int result;\n"; + dest << indent << "Py_BEGIN_ALLOW_THREADS\n"; + dest << indent << "result = " << f.name << "(\n"; + indent.indent += 2; + for (size_t i = 0; i < args.size(); i++) { if (args[i].is_buffer()) { - dest << "&buffer_" << arg_names[i]; + dest << indent << "&b_" << arg_names[i] << ".halide_buf"; } else { - dest << "py_" << arg_names[i]; + dest << indent << "py_" << arg_names[i] << ""; + } + if (i < args.size() - 1) { + dest << ","; } + dest << "\n"; } - dest << ");\n"; - dest << " Py_END_ALLOW_THREADS\n"; - release_buffers(); - dest << R"INLINE_CODE( - if (result != 0) { - /* In the optimal case, we'd be generating an exception declared - * in python_bindings/src, but since we're self-contained, - * we don't have access to that API. */ - PyErr_Format(PyExc_ValueError, "Halide error %d", result); - return nullptr; + indent.indent -= 2; + dest << indent << ");\n"; + dest << indent << "Py_END_ALLOW_THREADS\n"; + // Since the Python Buffer protocol is host-memory-only, we *must* + // flush results back to host, otherwise the output buffer will contain + // random garbage. (We need a better solution for this, see https://github.com/halide/Halide/issues/6868) + for (size_t i = 0; i < args.size(); i++) { + if (args[i].is_buffer() && args[i].is_output()) { + dest << indent << "if (result == 0) result = halide_copy_to_host(nullptr, &b_" << arg_names[i] << ".halide_buf);\n"; + } } - Py_INCREF(Py_True); - return Py_True; -)INLINE_CODE"; + dest << indent << "if (result != 0) {\n"; + indent.indent += 2; + dest << indent << "PyErr_Format(PyExc_ValueError, \"Halide error %d\", result);\n"; + dest << indent << "return nullptr;\n"; + indent.indent -= 2; + dest << indent << "}\n"; + dest << "\n"; + + dest << indent << "Py_INCREF(Py_None);\n"; + dest << indent << "return Py_None;\n"; + indent.indent -= 2; dest << "}\n"; + dest << "\n"; + dest << "} // namespace\n"; } } // namespace Internal diff --git a/src/PythonExtensionGen.h b/src/PythonExtensionGen.h index 15f005f733ed..f7bed3b98892 100644 --- a/src/PythonExtensionGen.h +++ b/src/PythonExtensionGen.h @@ -22,11 +22,8 @@ class PythonExtensionGen { private: std::ostream &dest; - std::vector buffer_refs; void compile(const LoweredFunc &f); - void convert_buffer(const std::string &name, const LoweredArgument *arg); - void release_buffers(const std::string &prefix); }; } // namespace Internal diff --git a/src/Random.cpp b/src/Random.cpp index bb132a9ea536..111ec73ebb5e 100644 --- a/src/Random.cpp +++ b/src/Random.cpp @@ -64,7 +64,7 @@ Expr rng32(const Expr &x) { } // namespace Expr random_int(const vector &e) { - internal_assert(e.size()); + internal_assert(!e.empty()); internal_assert(e[0].type() == Int(32) || e[0].type() == UInt(32)); // Permute the first term Expr result = rng32(cast(UInt(32), e[0])); diff --git a/src/Realization.cpp b/src/Realization.cpp index 9511d5a5a22f..0566eddbae18 100644 --- a/src/Realization.cpp +++ b/src/Realization.cpp @@ -5,36 +5,38 @@ namespace Halide { -/** The number of images in the Realization. */ size_t Realization::size() const { return images.size(); } -/** Get a const reference to one of the images. */ const Buffer &Realization::operator[](size_t x) const { user_assert(x < images.size()) << "Realization access out of bounds\n"; return images[x]; } -/** Get a reference to one of the images. */ Buffer &Realization::operator[](size_t x) { user_assert(x < images.size()) << "Realization access out of bounds\n"; return images[x]; } -/** Construct a Realization that refers to the buffers in an - * existing vector of Buffer<> */ -Realization::Realization(std::vector> &e) +Realization::Realization(const Buffer &e) + : images({e}) { +} + +Realization::Realization(Buffer &&e) + : images({std::move(e)}) { +} + +Realization::Realization(const std::vector> &e) : images(e) { - user_assert(!e.empty()) << "Realizations must have at least one element\n"; + user_assert(!images.empty()) << "Realizations must have at least one element\n"; +} + +Realization::Realization(std::vector> &&e) + : images(std::move(e)) { + user_assert(!images.empty()) << "Realizations must have at least one element\n"; } -/** Call device_sync() for all Buffers in the Realization. - * If one of the calls returns an error, subsequent Buffers won't have - * device_sync called; thus callers should consider a nonzero return - * code to mean that potentially all of the Buffers are in an indeterminate - * state of sync. - * Calling this explicitly should rarely be necessary, except for profiling. */ int Realization::device_sync(void *ctx) { for (auto &b : images) { int result = b.device_sync(ctx); diff --git a/src/Realization.h b/src/Realization.h index 29596d0b3218..867637d44f66 100644 --- a/src/Realization.h +++ b/src/Realization.h @@ -33,7 +33,8 @@ class Realization { /** Single-element realizations are implicitly castable to Buffers. */ template operator Buffer() const { - return images[0].as(); + // use our operator[] overload so that we get proper range-checking + return (*this)[0].as(); } /** Construct a Realization that acts as a reference to some @@ -41,15 +42,35 @@ class Realization { * const. */ template, Args...>::value>::type> - Realization(Buffer &a, Args &&...args) { - images = std::vector>({a, args...}); + HALIDE_ATTRIBUTE_DEPRECATED("Call Realization() with an explicit vector of Buffer<> instead.") + Realization(Buffer &a, Buffer &b, Args &&...args) + : Realization(std::vector>({a, b, args...})) { } + /** Construct a Realization that acts as a reference to a single + * existing Buffer. The element type of the Buffer may not be + * const. */ + // @{ + explicit Realization(const Buffer &e); + explicit Realization(Buffer &&e); + // @} + /** Construct a Realization that refers to the buffers in an - * existing vector of Buffer<> */ - explicit Realization(std::vector> &e); + * existing vector of Buffer<>. The element type of the Buffer(s) may not be + * const */ + // @{ + explicit Realization(const std::vector> &e); + explicit Realization(std::vector> &&e); + // This ctor allows us to avoid ambiguity when the vector is specified as + // a braced literal, e.g. `Realization({first, second})` + explicit Realization(std::initializer_list> e) + : Realization(std::vector>{e}) { + } + // @} /** Call device_sync() for all Buffers in the Realization. * If one of the calls returns an error, subsequent Buffers won't have diff --git a/src/RealizationOrder.cpp b/src/RealizationOrder.cpp index 78bd96234062..c712d3a43238 100644 --- a/src/RealizationOrder.cpp +++ b/src/RealizationOrder.cpp @@ -163,9 +163,16 @@ void populate_fused_pairs_list(const string &func, const Definition &def, func, stage_index, fuse_level.var().name()); if (fuse_level.stage_index() == 0) { parent.definition().schedule().fused_pairs().push_back(pair); + for (auto &s : parent.definition().specializations()) { + s.definition.schedule().fused_pairs().push_back(pair); + } } else { internal_assert(fuse_level.stage_index() > 0); - parent.update(fuse_level.stage_index() - 1).schedule().fused_pairs().push_back(pair); + auto &fuse_stage = parent.update(fuse_level.stage_index() - 1); + fuse_stage.schedule().fused_pairs().push_back(pair); + for (auto &s : fuse_stage.specializations()) { + s.definition.schedule().fused_pairs().push_back(pair); + } } } diff --git a/src/RegionCosts.cpp b/src/RegionCosts.cpp index 38ed44541edc..b7760b54bb8d 100644 --- a/src/RegionCosts.cpp +++ b/src/RegionCosts.cpp @@ -80,6 +80,11 @@ class ExprCost : public IRVisitor { arith += 1; } + void visit(const Reinterpret *op) override { + op->value.accept(this); + // `Reinterpret` is a no-op and does *not* incur any cost. + } + template void visit_binary_operator(const T *op, int op_cost) { op->a.accept(this); @@ -219,7 +224,7 @@ class ExprCost : public IRVisitor { // TODO: Improve the cost model. In some architectures (e.g. ARM or // NEON), count_leading_zeros should be as cheap as bitwise ops. // div_round_to_zero and mod_round_to_zero can also get fairly expensive. - if (call->is_intrinsic(Call::reinterpret) || call->is_intrinsic(Call::bitwise_and) || + if (call->is_intrinsic(Call::bitwise_and) || call->is_intrinsic(Call::bitwise_not) || call->is_intrinsic(Call::bitwise_xor) || call->is_intrinsic(Call::bitwise_or) || call->is_intrinsic(Call::shift_left) || call->is_intrinsic(Call::shift_right) || call->is_intrinsic(Call::div_round_to_zero) || @@ -371,7 +376,7 @@ Expr get_func_value_size(const Function &f) { Cost compute_expr_cost(Expr expr) { // TODO: Handle likely - //expr = LikelyExpression().mutate(expr); + // expr = LikelyExpression().mutate(expr); expr = simplify(expr); ExprCost cost_visitor; expr.accept(&cost_visitor); @@ -380,7 +385,7 @@ Cost compute_expr_cost(Expr expr) { map compute_expr_detailed_byte_loads(Expr expr) { // TODO: Handle likely - //expr = LikelyExpression().mutate(expr); + // expr = LikelyExpression().mutate(expr); expr = simplify(expr); ExprCost cost_visitor; expr.accept(&cost_visitor); @@ -508,7 +513,7 @@ RegionCosts::stage_detailed_load_costs(const string &func, int stage, if (curr_f.has_extern_definition()) { // TODO(psuriana): We need a better cost for extern function - //load_costs.emplace(func, Int(64).max()); + // load_costs.emplace(func, Int(64).max()); load_costs.emplace(func, Expr()); } else { Definition def = get_stage_definition(curr_f, stage); diff --git a/src/RemoveUndef.cpp b/src/RemoveUndef.cpp index d96c503b7085..a4889f6cc3b5 100644 --- a/src/RemoveUndef.cpp +++ b/src/RemoveUndef.cpp @@ -59,6 +59,18 @@ class RemoveUndef : public IRMutator { } } + Expr visit(const Reinterpret *op) override { + Expr value = mutate(op->value); + if (!value.defined()) { + return Expr(); + } + if (value.same_as(op->value)) { + return op; + } else { + return Reinterpret::make(op->type, std::move(value)); + } + } + Expr visit(const Add *op) override { return mutate_binary_operator(op); } diff --git a/src/Schedule.h b/src/Schedule.h index 0b82c8358e8e..186506cbb593 100644 --- a/src/Schedule.h +++ b/src/Schedule.h @@ -282,8 +282,8 @@ struct Split { std::string old_var, outer, inner; Expr factor; bool exact; // Is it required that the factor divides the extent - // of the old var. True for splits of RVars. Forces - // tail strategy to be GuardWithIf. + // of the old var. True for splits of RVars. Forces + // tail strategy to be GuardWithIf. TailStrategy tail; enum SplitType { SplitVar = 0, diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index 52724e309ba6..4881da7c3d54 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -2369,9 +2369,18 @@ void validate_fused_groups_schedule(const vector> &fused_groups, validate_fused_group_schedule_helper( iter->first, 0, iter->second.definition(), env); + for (const auto &s : iter->second.definition().specializations()) { + validate_fused_group_schedule_helper( + iter->first, 0, s.definition, env); + } for (size_t i = 0; i < iter->second.updates().size(); ++i) { + const auto &update_stage = iter->second.updates()[i]; validate_fused_group_schedule_helper( - iter->first, i + 1, iter->second.updates()[i], env); + iter->first, i + 1, update_stage, env); + for (const auto &s : update_stage.specializations()) { + validate_fused_group_schedule_helper( + iter->first, i + 1, s.definition, env); + } } } } diff --git a/src/Simplify_Call.cpp b/src/Simplify_Call.cpp index b3d2282ffb56..117e2b63a47a 100644 --- a/src/Simplify_Call.cpp +++ b/src/Simplify_Call.cpp @@ -1,5 +1,6 @@ #include "Simplify_Internal.h" +#include "FindIntrinsics.h" #include "Simplify.h" #ifdef _MSC_VER @@ -132,7 +133,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { // If we know the sign of this shift, change it to an unsigned shift. if (b_info.min_defined && b_info.min >= 0) { b = mutate(cast(b.type().with_code(halide_type_uint), b), nullptr); - } else if (b_info.max_defined && b_info.max <= 0) { + } else if (b.type().is_int() && b_info.max_defined && b_info.max <= 0) { result_op = Call::get_intrinsic_name(op->is_intrinsic(Call::shift_right) ? Call::shift_left : Call::shift_right); b = mutate(cast(b.type().with_code(halide_type_uint), -b), nullptr); } @@ -165,12 +166,14 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { } } - // Rewrite shifts with negated RHSes as shifts of the other direction. - if (const Sub *sub = b.as()) { - if (is_const_zero(sub->a)) { - result_op = Call::get_intrinsic_name(op->is_intrinsic(Call::shift_right) ? Call::shift_left : Call::shift_right); - b = sub->b; - return mutate(Call::make(op->type, result_op, {a, b}, Call::PureIntrinsic), bounds); + // Rewrite shifts with signed negated RHSes as shifts of the other direction. + if (b.type().is_int()) { + if (const Sub *sub = b.as()) { + if (is_const_zero(sub->a)) { + result_op = Call::get_intrinsic_name(op->is_intrinsic(Call::shift_right) ? Call::shift_left : Call::shift_right); + b = sub->b; + return mutate(Call::make(op->type, result_op, {a, b}, Call::PureIntrinsic), bounds); + } } } @@ -280,25 +283,6 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { } else { return a ^ b; } - } else if (op->is_intrinsic(Call::reinterpret)) { - Expr a = mutate(op->args[0], nullptr); - - int64_t ia; - uint64_t ua; - bool vector = op->type.is_vector() || a.type().is_vector(); - if (op->type == a.type()) { - return a; - } else if (const_int(a, &ia) && op->type.is_uint() && !vector) { - // int -> uint - return make_const(op->type, (uint64_t)ia); - } else if (const_uint(a, &ua) && op->type.is_int() && !vector) { - // uint -> int - return make_const(op->type, (int64_t)ua); - } else if (a.same_as(op->args[0])) { - return op; - } else { - return reinterpret(op->type, a); - } } else if (op->is_intrinsic(Call::abs)) { // Constant evaluate abs(x). ExprInfo a_bounds; @@ -368,6 +352,21 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { } else { return absd(a, b); } + } else if (op->is_intrinsic(Call::saturating_cast)) { + internal_assert(op->args.size() == 1); + ExprInfo a_bounds; + Expr a = mutate(op->args[0], &a_bounds); + + // TODO(rootjalex): We could be intelligent about using a_bounds to remove saturating_casts; + + if (is_const(a)) { + a = lower_saturating_cast(op->type, a); + return mutate(a, bounds); + } else if (!a.same_as(op->args[0])) { + return saturating_cast(op->type, a); + } else { + return op; + } } else if (op->is_intrinsic(Call::stringify)) { // Eagerly concat constant arguments to a stringify. bool changed = false; @@ -777,6 +776,8 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { debug(2) << "Simplifier: unhandled PureExtern: " << op->name; } else if (op->is_intrinsic(Call::signed_integer_overflow)) { clear_bounds_info(bounds); + } else if (op->is_intrinsic(Call::concat_bits) && op->args.size() == 1) { + return mutate(op->args[0], bounds); } // No else: we want to fall thru from the PureExtern clause. diff --git a/src/Simplify_Internal.h b/src/Simplify_Internal.h index 0be084d61154..a510e5c51f64 100644 --- a/src/Simplify_Internal.h +++ b/src/Simplify_Internal.h @@ -309,6 +309,7 @@ class Simplify : public VariadicVisitor { Expr visit(const StringImm *op, ExprInfo *bounds); Expr visit(const Broadcast *op, ExprInfo *bounds); Expr visit(const Cast *op, ExprInfo *bounds); + Expr visit(const Reinterpret *op, ExprInfo *bounds); Expr visit(const Variable *op, ExprInfo *bounds); Expr visit(const Add *op, ExprInfo *bounds); Expr visit(const Sub *op, ExprInfo *bounds); diff --git a/src/Simplify_Reinterpret.cpp b/src/Simplify_Reinterpret.cpp new file mode 100644 index 000000000000..c5d8d07ce233 --- /dev/null +++ b/src/Simplify_Reinterpret.cpp @@ -0,0 +1,28 @@ +#include "Simplify_Internal.h" + +namespace Halide { +namespace Internal { + +Expr Simplify::visit(const Reinterpret *op, ExprInfo *bounds) { + Expr a = mutate(op->value, nullptr); + + int64_t ia; + uint64_t ua; + bool vector = op->type.is_vector() || a.type().is_vector(); + if (op->type == a.type()) { + return a; + } else if (const_int(a, &ia) && op->type.is_uint() && !vector) { + // int -> uint + return make_const(op->type, (uint64_t)ia); + } else if (const_uint(a, &ua) && op->type.is_int() && !vector) { + // uint -> int + return make_const(op->type, (int64_t)ua); + } else if (a.same_as(op->value)) { + return op; + } else { + return reinterpret(op->type, a); + } +} + +} // namespace Internal +} // namespace Halide diff --git a/src/Simplify_Shuffle.cpp b/src/Simplify_Shuffle.cpp index 55469cd32817..35622aee9c4e 100644 --- a/src/Simplify_Shuffle.cpp +++ b/src/Simplify_Shuffle.cpp @@ -1,4 +1,5 @@ #include "Deinterleave.h" +#include "IROperator.h" #include "Simplify_Internal.h" namespace Halide { @@ -191,7 +192,34 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *bounds) { } } } + + // Try to collapse an interleave of a series of extract_bits into a vector reinterpret. + if (const Call *extract = new_vectors[0].as()) { + if (extract->is_intrinsic(Call::extract_bits) && + is_const_zero(extract->args[1])) { + int n = (int)new_vectors.size(); + Expr base = extract->args[0]; + bool can_collapse = base.type().bits() == n * op->type.bits(); + for (int i = 1; can_collapse && i < n; i++) { + const Call *c = new_vectors[i].as(); + if (!(c->is_intrinsic(Call::extract_bits) && + is_const(c->args[1], i * op->type.bits()) && + equal(base, c->args[0]))) { + can_collapse = false; + } + } + if (can_collapse) { + return Reinterpret::make(op->type, base); + } + } + } + } else if (op->is_concat()) { + // Bypass concat of a single vector (identity shuffle) + if (new_vectors.size() == 1) { + return new_vectors[0]; + } + // Try to collapse a concat of ramps into a single ramp. const Ramp *r = new_vectors[0].as(); if (r) { diff --git a/src/Solve.cpp b/src/Solve.cpp index 0b632ac52e45..d8ff919bb56c 100644 --- a/src/Solve.cpp +++ b/src/Solve.cpp @@ -1124,6 +1124,10 @@ class SolveForInterval : public IRVisitor { fail(); } + void visit(const Reinterpret *op) override { + fail(); + } + void visit(const Load *op) override { fail(); } diff --git a/src/SpirvIR.cpp b/src/SpirvIR.cpp new file mode 100644 index 000000000000..621e79de7c62 --- /dev/null +++ b/src/SpirvIR.cpp @@ -0,0 +1,1736 @@ +#include "SpirvIR.h" +#include + +#ifdef WITH_SPIRV + +namespace Halide { +namespace Internal { + +/** SpvInstruction implementation **/ +SpvInstruction SpvInstruction::make(SpvOp op_code) { + SpvInstruction instance; + instance.contents = SpvInstructionContentsPtr(new SpvInstructionContents); + instance.contents->op_code = op_code; + instance.contents->result_id = SpvNoResult; + instance.contents->type_id = SpvNoType; + return instance; +} + +void SpvInstruction::set_block(SpvBlock block) { + check_defined(); + contents->block = std::move(block); +} + +void SpvInstruction::set_result_id(SpvId result_id) { + check_defined(); + contents->result_id = result_id; +} + +void SpvInstruction::set_type_id(SpvId type_id) { + check_defined(); + contents->type_id = type_id; +} + +void SpvInstruction::set_op_code(SpvOp op_code) { + check_defined(); + contents->op_code = op_code; +} + +void SpvInstruction::add_operand(SpvId id) { + check_defined(); + contents->operands.push_back(id); + contents->immediates.push_back(false); +} + +void SpvInstruction::add_immediate(SpvId id) { + check_defined(); + contents->operands.push_back(id); + contents->immediates.push_back(true); +} + +SpvId SpvInstruction::result_id() const { + check_defined(); + return contents->result_id; +} + +SpvId SpvInstruction::type_id() const { + check_defined(); + return contents->type_id; +} + +SpvOp SpvInstruction::op_code() const { + check_defined(); + return contents->op_code; +} + +SpvId SpvInstruction::operand(uint32_t index) { + check_defined(); + return contents->operands[index]; +} + +bool SpvInstruction::has_type() const { + if (!is_defined()) { + return false; + } + return contents->type_id != SpvNoType; +} + +bool SpvInstruction::has_result() const { + if (!is_defined()) { + return false; + } + return contents->result_id != SpvNoResult; +} + +bool SpvInstruction::is_defined() const { + return contents.defined(); +} + +bool SpvInstruction::is_immediate(uint32_t index) const { + check_defined(); + return contents->immediates[index]; +} + +uint32_t SpvInstruction::length() const { + check_defined(); + return (uint32_t)contents->operands.size(); +} + +SpvBlock SpvInstruction::block() const { + check_defined(); + return contents->block; +} + +void SpvInstruction::add_data(uint32_t bytes, const void *data) { + check_defined(); + uint32_t extra_words = (bytes + 3) / 4; + const uint8_t *ptr = (const uint8_t *)data; + size_t bytes_copied = 0; + for (uint32_t i = 0; i < extra_words; i++) { + size_t copy_size = std::min(bytes - bytes_copied, (size_t)4); + SpvId entry = 0; + memcpy(&entry, ptr, copy_size); + bytes_copied += copy_size; + add_immediate(entry); + ptr++; + } +} + +void SpvInstruction::add_string(const std::string &str) { + check_defined(); + add_data(str.length() + 1, (const void *)str.c_str()); +} + +void SpvInstruction::check_defined() const { + user_assert(is_defined()) << "An SpvInstruction must be defined before accessing its properties\n"; +} + +void SpvInstruction::encode(SpvBinary &binary) const { + check_defined(); + + // Count the number of 32-bit words to represent the instruction + uint32_t word_count = 1; + word_count += has_type() ? 1 : 0; + word_count += has_result() ? 1 : 0; + word_count += length(); + + // Preface the instruction with the format + // - high 16-bits indicate instruction length (number of 32-bit words) + // - low 16-bits indicate op code + binary.push_back(((word_count) << SpvWordCountShift) | contents->op_code); + if (has_type()) { + binary.push_back(contents->type_id); + } + if (has_result()) { + binary.push_back(contents->result_id); + } + for (SpvId id : contents->operands) { + binary.push_back(id); + } +} + +// -- + +SpvBlock SpvBlock::make(SpvFunction func, SpvId block_id) { + SpvBlock instance; + instance.contents = SpvBlockContentsPtr(new SpvBlockContents()); + instance.contents->parent = std::move(func); + instance.contents->block_id = block_id; + return instance; +} + +void SpvBlock::add_instruction(SpvInstruction inst) { + check_defined(); + inst.set_block(*this); + contents->instructions.push_back(inst); +} + +void SpvBlock::add_variable(SpvInstruction var) { + check_defined(); + var.set_block(*this); + contents->instructions.push_back(var); +} + +void SpvBlock::set_function(SpvFunction func) { + check_defined(); + contents->parent = std::move(func); +} + +SpvFunction SpvBlock::function() const { + check_defined(); + return contents->parent; +} + +const SpvBlock::Instructions &SpvBlock::instructions() const { + check_defined(); + return contents->instructions; +} + +const SpvBlock::Variables &SpvBlock::variables() const { + check_defined(); + return contents->variables; +} + +bool SpvBlock::is_reachable() const { + check_defined(); + return contents->reachable; +} + +bool SpvBlock::is_defined() const { + return contents.defined(); +} + +bool SpvBlock::is_terminated() const { + check_defined(); + switch (contents->instructions.back().op_code()) { + case SpvOpBranch: + case SpvOpBranchConditional: + case SpvOpSwitch: + case SpvOpKill: + case SpvOpReturn: + case SpvOpReturnValue: + case SpvOpUnreachable: + return true; + default: + return false; + }; +} + +SpvId SpvBlock::id() const { + check_defined(); + return contents->block_id; +} + +void SpvBlock::check_defined() const { + user_assert(is_defined()) << "An SpvBlock must be defined before accessing its properties\n"; +} + +void SpvBlock::encode(SpvBinary &binary) const { + check_defined(); + + // add a label for this block + SpvInstruction label = SpvFactory::label(contents->block_id); + label.encode(binary); + + // encode all variables + for (const SpvInstruction &variable : contents->variables) { + variable.encode(binary); + } + // encode all instructions + for (const SpvInstruction &instruction : contents->instructions) { + instruction.encode(binary); + } +} + +// -- + +SpvFunction SpvFunction::make(SpvId func_type_id, SpvId func_id, SpvId return_type_id, uint32_t control_mask) { + SpvFunction instance; + instance.contents = SpvFunctionContentsPtr(new SpvFunctionContents()); + instance.contents->function_id = func_id; + instance.contents->function_type_id = func_type_id; + instance.contents->return_type_id = return_type_id; + instance.contents->control_mask = control_mask; + instance.contents->declaration = SpvFactory::function(return_type_id, func_id, control_mask, func_type_id); + return instance; +} + +bool SpvFunction::is_defined() const { + return contents.defined(); +} + +void SpvFunction::add_block(const SpvBlock &block) { + check_defined(); + contents->blocks.push_back(block); +} + +void SpvFunction::add_parameter(const SpvInstruction ¶m) { + check_defined(); + contents->parameters.push_back(param); +} + +uint32_t SpvFunction::parameter_count() const { + check_defined(); + return (uint32_t)contents->parameters.size(); +} + +SpvBlock SpvFunction::entry_block() const { + check_defined(); + return contents->blocks.front(); +} + +SpvPrecision SpvFunction::return_precision() const { + check_defined(); + SpvId return_id = contents->declaration.result_id(); + SpvFunctionContents::PrecisionMap::const_iterator it = contents->precision.find(return_id); + if (it == contents->precision.end()) { + return SpvPrecision::SpvFullPrecision; + } else { + return contents->precision[return_id]; + } +} + +void SpvFunction::set_return_precision(SpvPrecision precision) { + check_defined(); + SpvId return_id = contents->declaration.result_id(); + SpvFunctionContents::PrecisionMap::const_iterator it = contents->precision.find(return_id); + if (it == contents->precision.end()) { + contents->precision.insert({return_id, precision}); + } else { + contents->precision[return_id] = precision; + } +} + +SpvPrecision SpvFunction::parameter_precision(uint32_t index) const { + check_defined(); + user_assert(contents->parameters.size() > index) << "Invalid parameter index specified!\n"; + SpvId param_id = contents->parameters[index].result_id(); + SpvFunctionContents::PrecisionMap::const_iterator it = contents->precision.find(param_id); + if (it == contents->precision.end()) { + return SpvPrecision::SpvFullPrecision; + } else { + return contents->precision[param_id]; + } +} + +void SpvFunction::set_module(SpvModule module) { + check_defined(); + contents->parent = std::move(module); +} + +SpvInstruction SpvFunction::declaration() const { + check_defined(); + return contents->declaration; +} + +SpvModule SpvFunction::module() const { + check_defined(); + return contents->parent; +} + +SpvId SpvFunction::return_type_id() const { + check_defined(); + return contents->return_type_id; +} + +SpvId SpvFunction::type_id() const { + check_defined(); + return contents->function_type_id; +} + +SpvId SpvFunction::id() const { + check_defined(); + return contents->function_id; +} + +void SpvFunction::check_defined() const { + user_assert(is_defined()) << "An SpvFunction must be defined before accessing its properties\n"; +} + +void SpvFunction::encode(SpvBinary &binary) const { + check_defined(); + contents->declaration.encode(binary); + for (const SpvInstruction ¶m : contents->parameters) { + param.encode(binary); + } + for (const SpvBlock &block : contents->blocks) { + block.encode(binary); + } + + SpvInstruction inst = SpvFactory::function_end(); + inst.encode(binary); +} + +// -- + +SpvModule SpvModule::make(SpvId module_id, + SpvSourceLanguage source_language, + SpvAddressingModel addressing_model, + SpvMemoryModel memory_model) { + SpvModule instance; + instance.contents = SpvModuleContentsPtr(new SpvModuleContents()); + instance.contents->module_id = module_id; + instance.contents->source_language = source_language; + instance.contents->addressing_model = addressing_model; + instance.contents->memory_model = memory_model; + return instance; +} + +bool SpvModule::is_defined() const { + return contents.defined(); +} + +void SpvModule::add_debug(const SpvInstruction &val) { + check_defined(); + contents->debug.push_back(val); +} + +void SpvModule::add_annotation(const SpvInstruction &val) { + check_defined(); + contents->annotations.push_back(val); +} + +void SpvModule::add_type(const SpvInstruction &val) { + check_defined(); + contents->types.push_back(val); +} + +void SpvModule::add_constant(const SpvInstruction &val) { + check_defined(); + contents->constants.push_back(val); +} + +void SpvModule::add_global(const SpvInstruction &val) { + check_defined(); + contents->globals.push_back(val); +} + +void SpvModule::add_execution_mode(const SpvInstruction &val) { + check_defined(); + contents->execution_modes.push_back(val); +} + +void SpvModule::add_instruction(const SpvInstruction &val) { + check_defined(); + contents->instructions.push_back(val); +} + +void SpvModule::add_function(SpvFunction val) { + check_defined(); + val.set_module(*this); + contents->functions.emplace_back(val); +} + +void SpvModule::add_entry_point(const std::string &name, SpvInstruction inst) { + check_defined(); + contents->entry_points[name] = std::move(inst); +} + +void SpvModule::set_source_language(SpvSourceLanguage val) { + check_defined(); + contents->source_language = val; +} + +void SpvModule::set_addressing_model(SpvAddressingModel val) { + check_defined(); + contents->addressing_model = val; +} + +void SpvModule::set_memory_model(SpvMemoryModel val) { + check_defined(); + contents->memory_model = val; +} + +SpvSourceLanguage SpvModule::source_language() const { + check_defined(); + return contents->source_language; +} + +SpvAddressingModel SpvModule::addressing_model() const { + check_defined(); + return contents->addressing_model; +} + +const SpvModule::Instructions &SpvModule::execution_modes() const { + check_defined(); + return contents->execution_modes; +} + +SpvMemoryModel SpvModule::memory_model() const { + check_defined(); + return contents->memory_model; +} + +SpvInstruction SpvModule::entry_point(const std::string &name) const { + check_defined(); + if (contents->entry_points.find(name) != contents->entry_points.end()) { + return contents->entry_points[name]; + } else { + SpvInstruction noop = SpvInstruction::make(SpvOpNop); + return noop; + } +} + +void SpvModule::require_extension(const std::string &extension) { + check_defined(); + if (contents->extensions.find(extension) == contents->extensions.end()) { + contents->extensions.insert(extension); + } +} + +bool SpvModule::is_extension_required(const std::string &extension) const { + check_defined(); + if (contents->extensions.find(extension) != contents->extensions.end()) { + return true; + } + return false; +} + +void SpvModule::require_capability(SpvCapability capability) { + check_defined(); + if (contents->capabilities.find(capability) == contents->capabilities.end()) { + contents->capabilities.insert(capability); + } +} + +bool SpvModule::is_capability_required(SpvCapability capability) const { + check_defined(); + if (contents->capabilities.find(capability) != contents->capabilities.end()) { + return true; + } + return false; +} + +SpvModule::EntryPointNames SpvModule::entry_point_names() const { + check_defined(); + SpvModule::EntryPointNames entry_point_names(contents->entry_points.size()); + for (const SpvModuleContents::EntryPoints::value_type &v : contents->entry_points) { + entry_point_names.push_back(v.first); + } + return entry_point_names; +} + +SpvId SpvModule::id() const { + check_defined(); + return contents->module_id; +} + +void SpvModule::check_defined() const { + user_assert(is_defined()) << "An SpvModule must be defined before accessing its properties\n"; +} + +void SpvModule::encode(SpvBinary &binary) const { + check_defined(); + + // 0. Encode the header + binary.push_back(SpvMagicNumber); + binary.push_back(SpvVersion); + binary.push_back(contents->source_language); + binary.push_back(0); // Bound placeholder (aka last id used) + binary.push_back(0); // Reserved for schema. + + // 1. Capabilities + for (const SpvCapability &capability : contents->capabilities) { + SpvInstruction inst = SpvFactory::capability(capability); + inst.encode(binary); + } + + // 2. Extensions + for (const std::string &extension : contents->extensions) { + SpvInstruction inst = SpvFactory::extension(extension); + inst.encode(binary); + } + + // 3. Extended Instruction Set Imports + for (const std::string &import : contents->imports) { + SpvInstruction inst = SpvFactory::import(import); + inst.encode(binary); + } + + // 4. Memory Model + SpvInstruction memory_model_inst = SpvFactory::memory_model(contents->addressing_model, contents->memory_model); + memory_model_inst.encode(binary); + + // 5. Entry Points + for (const SpvModuleContents::EntryPoints::value_type &value : contents->entry_points) { + SpvInstruction entry_point_inst = value.second; + entry_point_inst.encode(binary); + } + + // 6. Execution Modes + for (const SpvInstruction &inst : contents->execution_modes) { + inst.encode(binary); + } + + // 7. Debug + for (const SpvInstruction &inst : contents->debug) { + inst.encode(binary); + } + + // 8. Annotations + for (const SpvInstruction &inst : contents->annotations) { + inst.encode(binary); + } + + // 9a. Type Declarations + for (const SpvInstruction &inst : contents->types) { + inst.encode(binary); + } + + // 9b. Constants + for (const SpvInstruction &inst : contents->constants) { + inst.encode(binary); + } + + // 9c. Globals + for (const SpvInstruction &inst : contents->globals) { + inst.encode(binary); + } + + // 10-11. Function Declarations & Definitions + for (const SpvFunction &func : contents->functions) { + func.encode(binary); + } +} + +// -- + +SpvBuilder::SpvBuilder() { + SpvId module_id = declare_id(SpvModuleId); + module = SpvModule::make(module_id); +} + +SpvId SpvBuilder::reserve_id(SpvKind kind) { + return declare_id(kind); +} + +SpvId SpvBuilder::declare_id(SpvKind kind) { + // use type-agnostic non-overlapping increasing ids + SpvId item_id = kind_map.size() + 1; + kind_map[item_id] = kind; + return item_id; +} + +SpvKind SpvBuilder::kind_of(SpvId item_id) { + KindMap::const_iterator it = kind_map.find(item_id); + if (it != kind_map.end()) { + return SpvInvalidItem; + } + return it->second; +} + +void SpvBuilder::encode(SpvBinary &binary) const { + // Encode the module + module.encode(binary); +} + +SpvId SpvBuilder::map_type(const Type &type, uint32_t array_size) { + SpvId type_id = lookup_type(type, array_size); + if (type_id == SpvInvalidId) { + type_id = declare_type(type, array_size); + } + return type_id; +} + +SpvId SpvBuilder::map_pointer_type(const Type &type, SpvStorageClass storage_class) { + SpvId ptr_type_id = lookup_pointer_type(type, storage_class); + if (ptr_type_id == SpvInvalidId) { + ptr_type_id = declare_pointer_type(ptr_type_id, storage_class); + } + return ptr_type_id; +} + +SpvId SpvBuilder::map_pointer_type(SpvId type_id, SpvStorageClass storage_class) { + SpvId ptr_type_id = lookup_pointer_type(type_id, storage_class); + if (ptr_type_id == SpvInvalidId) { + ptr_type_id = declare_pointer_type(type_id, storage_class); + } + return ptr_type_id; +} + +SpvId SpvBuilder::map_function_type(SpvId return_type, const ParamTypes ¶m_types) { + SpvId type_id = lookup_function_type(return_type, param_types); + if (type_id == SpvInvalidId) { + type_id = declare_function_type(return_type, param_types); + } + return type_id; +} + +SpvId SpvBuilder::map_constant(const Type &type, const void *data) { + SpvId result_id = lookup_constant(type, data); + if (result_id == SpvInvalidId) { + result_id = declare_constant(type, data); + } + return result_id; +} + +void SpvBuilder::add_entry_point(const std::string &name, + SpvId func_id, SpvExecutionModel exec_model, + const Variables &variables) { + + SpvInstruction inst = SpvFactory::entry_point(exec_model, func_id, name, variables); + module.add_entry_point(name, inst); +} + +SpvFunction SpvBuilder::add_function(SpvId return_type_id, const ParamTypes ¶m_types) { + SpvId func_id = declare_id(SpvFunctionId); + SpvId func_type_id = map_function_type(return_type_id, param_types); + SpvFunction func = SpvFunction::make(func_type_id, func_id, return_type_id); + for (SpvId param_type_id : param_types) { + SpvId param_id = declare_id(SpvParameterId); + SpvInstruction param_inst = SpvFactory::function_parameter(param_type_id, param_id); + func.add_parameter(param_inst); + map_instruction(param_inst); + } + SpvId block_id = declare_id(SpvBlockId); + SpvBlock entry_block = SpvBlock::make(func, block_id); + func.add_block(entry_block); + module.add_function(func); + function_map[func_id] = func; + map_instruction(func.declaration()); + return func; +} + +SpvId SpvBuilder::add_global_variable(SpvId type_id, uint32_t storage_class, SpvId init_id) { + SpvId var_id = reserve_id(SpvVariableId); + module.add_global(SpvFactory::variable(var_id, type_id, storage_class, init_id)); + return var_id; +} + +SpvId SpvBuilder::add_variable(SpvId type_id, uint32_t storage_class, SpvId init_id) { + SpvId var_id = reserve_id(SpvVariableId); + current_block().add_variable(SpvFactory::variable(var_id, type_id, storage_class, init_id)); + return var_id; +} + +void SpvBuilder::add_annotation(SpvId target_id, SpvDecoration decoration_type, const Literals &literals) { + SpvInstruction inst = SpvFactory::decorate(target_id, decoration_type, literals); + current_module().add_annotation(inst); +} + +void SpvBuilder::add_struct_annotation(SpvId struct_type_id, uint32_t member_index, SpvDecoration decoration_type, const Literals &literals) { + SpvInstruction inst = SpvFactory::decorate_member(struct_type_id, member_index, decoration_type, literals); + current_module().add_annotation(inst); +} + +void SpvBuilder::add_execution_mode_local_size(SpvId func_id, + uint32_t wg_size_x, uint32_t wg_size_y, uint32_t wg_size_z) { + + wg_size_x = std::max(wg_size_x, (uint32_t)1); + wg_size_y = std::max(wg_size_y, (uint32_t)1); + wg_size_z = std::max(wg_size_z, (uint32_t)1); + + SpvInstruction exec_mode_inst = SpvFactory::exec_mode_local_size(func_id, wg_size_x, wg_size_y, wg_size_z); + module.add_execution_mode(exec_mode_inst); +} + +void SpvBuilder::enter_block(const SpvBlock &block) { + block_stack.push(block); +} + +SpvBlock SpvBuilder::current_block() const { + SpvBlock block; + if (!block_stack.empty()) { + block = block_stack.top(); + } + return block; +} + +SpvBlock SpvBuilder::leave_block() { + SpvBlock block; + if (!block_stack.empty()) { + block = block_stack.top(); + block_stack.pop(); + } + return block; +} + +SpvFunction SpvBuilder::lookup_function(SpvId func_id) const { + SpvFunction func; + FunctionMap::const_iterator it = function_map.find(func_id); + if (it != function_map.end()) { + func = it->second; + } + return func; +} + +void SpvBuilder::enter_function(const SpvFunction &func) { + function_stack.push(func); + enter_block(func.entry_block()); +} + +SpvFunction SpvBuilder::current_function() const { + SpvFunction func; + if (!function_stack.empty()) { + func = function_stack.top(); + } + return func; +} + +SpvFunction SpvBuilder::leave_function() { + SpvFunction func; + leave_block(); + if (!function_stack.empty()) { + func = function_stack.top(); + function_stack.pop(); + } + return func; +} + +void SpvBuilder::set_current_id(SpvId val) { + scope_id = val; +} + +SpvId SpvBuilder::current_id() const { + return scope_id; +} + +SpvModule SpvBuilder::current_module() const { + return module; +} + +void SpvBuilder::require_capability(SpvCapability capability) { + if (!module.is_capability_required(capability)) { + module.require_capability(capability); + } +} + +bool SpvBuilder::is_capability_required(SpvCapability capability) const { + return module.is_capability_required(capability); +} + +void SpvBuilder::require_extension(const std::string &extension) { + if (!module.is_extension_required(extension)) { + module.require_extension(extension); + } +} + +bool SpvBuilder::is_extension_required(const std::string &extension) const { + return module.is_extension_required(extension); +} + +SpvBuilder::TypeKey SpvBuilder::make_type_key(const Type &type, uint32_t array_size) const { + TypeKey key(4 + sizeof(uint32_t), ' '); + key[0] = type.code(); + key[1] = type.bits(); + key[2] = type.lanes() & 0xff; + key[3] = (type.lanes() >> 8) & 0xff; + for (size_t i = 0; i < sizeof(uint32_t); i++) { + key[i + 4] = (array_size & 0xff); + array_size >>= 8; + } + return key; +} + +SpvId SpvBuilder::lookup_type(const Type &type, uint32_t array_size) const { + SpvBuilder::TypeKey type_key = make_type_key(type, array_size); + TypeMap::const_iterator it = type_map.find(type_key); + if (it == type_map.end()) { + return SpvInvalidId; + } + return it->second; +} + +SpvId SpvBuilder::declare_type(const Type &type, uint32_t array_size) { + SpvBuilder::TypeKey type_key = make_type_key(type, array_size); + TypeMap::const_iterator it = type_map.find(type_key); + if (it != type_map.end()) { + return it->second; + } + + if (array_size > 1) { + SpvId array_type_id = declare_id(SpvArrayTypeId); + SpvId element_type_id = declare_type(type, 1); + SpvInstruction inst = SpvFactory::array_type(array_type_id, element_type_id, array_size); + module.add_type(inst); + type_map[type_key] = array_type_id; + return array_type_id; + } + + SpvId type_id = SpvInvalidId; + if (type.is_vector()) { + type_id = declare_id(SpvVectorTypeId); + SpvId element_type_id = declare_type(type.with_lanes(1)); + SpvInstruction inst = SpvFactory::vector_type(type_id, element_type_id, type.lanes()); + module.add_type(inst); + } else { + if (type.is_handle()) { + type_id = declare_id(SpvVoidTypeId); + SpvInstruction inst = SpvFactory::void_type(type_id); + module.add_type(inst); + } else if (type.is_bool()) { + type_id = declare_id(SpvBoolTypeId); + SpvInstruction inst = SpvFactory::bool_type(type_id); + module.add_type(inst); + } else if (type.is_float()) { + type_id = declare_id(SpvFloatTypeId); + SpvInstruction inst = SpvFactory::float_type(type_id, type.bits()); + module.add_type(inst); + } else if (type.is_int_or_uint()) { + type_id = declare_id(SpvIntTypeId); + SpvId signedness = type.is_uint() ? 0 : 1; + SpvInstruction inst = SpvFactory::integer_type(type_id, type.bits(), signedness); + module.add_type(inst); + } else { + internal_error << "SPIRV: Unsupported type " << type << "\n"; + } + } + + type_map[type_key] = type_id; + return type_id; +} + +SpvBuilder::TypeKey SpvBuilder::make_struct_type_key(const StructMemberTypes &member_type_ids) const { + TypeKey key(member_type_ids.size() * sizeof(SpvId), ' '); + uint32_t index = 0; + for (SpvId type_id : member_type_ids) { + for (size_t i = 0; i < sizeof(uint32_t); i++, index++) { + key[index] = (type_id & 0xff); + type_id >>= 8; + } + } + return key; +} + +SpvId SpvBuilder::lookup_struct(const StructMemberTypes &member_type_ids) const { + TypeKey key = make_struct_type_key(member_type_ids); + TypeMap::const_iterator it = struct_map.find(key); + if (it != struct_map.end()) { + return it->second; + } + return SpvInvalidId; +} + +SpvId SpvBuilder::declare_struct(const StructMemberTypes &member_type_ids) { + TypeKey key = make_struct_type_key(member_type_ids); + TypeMap::const_iterator it = struct_map.find(key); + if (it != struct_map.end()) { + return it->second; + } + + SpvId struct_type_id = declare_id(SpvStructTypeId); + SpvInstruction inst = SpvFactory::struct_type(struct_type_id, member_type_ids); + module.add_type(inst); + struct_map[key] = struct_type_id; + return struct_type_id; +} + +SpvBuilder::PointerTypeKey SpvBuilder::make_pointer_type_key(const Type &type, SpvStorageClass storage_class) const { + SpvId base_type_id = lookup_type(type); + if (base_type_id == SpvInvalidId) { + internal_error << "SPIRV: Attempted to declare pointer type for undeclared base type! " << type << "\n"; + } + return std::make_pair(base_type_id, storage_class); +} + +SpvBuilder::PointerTypeKey SpvBuilder::make_pointer_type_key(SpvId base_type_id, SpvStorageClass storage_class) const { + return std::make_pair(base_type_id, storage_class); +} + +SpvId SpvBuilder::lookup_pointer_type(const Type &type, SpvStorageClass storage_class) const { + SpvId base_type_id = lookup_type(type); + if (base_type_id == SpvInvalidId) { + internal_error << "SPIRV: Attempted to lookup pointer type for undeclared base type! " << type << "\n"; + } + return lookup_pointer_type(base_type_id, storage_class); +} + +SpvId SpvBuilder::lookup_pointer_type(SpvId base_type_id, SpvStorageClass storage_class) const { + PointerTypeKey key = make_pointer_type_key(base_type_id, storage_class); + PointerTypeMap::const_iterator it = pointer_type_map.find(key); + if (it != pointer_type_map.end()) { + return it->second; + } + return SpvInvalidId; +} + +SpvId SpvBuilder::declare_pointer_type(const Type &type, SpvStorageClass storage_class) { + SpvId base_type_id = map_type(type); + return declare_pointer_type(base_type_id, storage_class); +} + +SpvId SpvBuilder::declare_pointer_type(SpvId base_type_id, SpvStorageClass storage_class) { + PointerTypeKey key = make_pointer_type_key(base_type_id, storage_class); + PointerTypeMap::const_iterator it = pointer_type_map.find(key); + if (it != pointer_type_map.end()) { + return it->second; + } + + SpvId pointer_type_id = declare_id(SpvPointerTypeId); + SpvInstruction inst = SpvFactory::pointer_type(pointer_type_id, storage_class, base_type_id); + module.add_type(inst); + pointer_type_map[key] = pointer_type_id; + return pointer_type_id; +} + +SpvBuilder::ConstantKey SpvBuilder::make_constant_key(const Type &type, const void *data) const { + ConstantKey key(type.bytes() + 4, ' '); + key[0] = type.code(); + key[1] = type.bits(); + key[2] = type.lanes() & 0xff; + key[3] = (type.lanes() >> 8) & 0xff; + const char *data_char = (const char *)data; + for (int i = 0; i < type.bytes(); i++) { + key[i + 4] = data_char[i]; + } + return key; +} + +SpvBuilder::ConstantKey SpvBuilder::make_bool_constant_key(bool value) const { + Type type = Bool(); + bool data = value; + return make_constant_key(type, &data); +} + +SpvBuilder::ConstantKey SpvBuilder::make_null_constant_key(const Type &type) const { + ConstantKey key(type.bytes() + 4, ' '); + key[0] = type.code(); + key[1] = type.bits(); + key[2] = type.lanes() & 0xff; + key[3] = (type.lanes() >> 8) & 0xff; + for (int i = 0; i < type.bytes(); i++) { + key[i + 4] = 0; + } + return key; +} + +SpvId SpvBuilder::lookup_null_constant(const Type &type) const { + ConstantKey key = make_null_constant_key(type); + ConstantMap::const_iterator it = constant_map.find(key); + if (it != constant_map.end()) { + return it->second; + } + return SpvInvalidId; +} + +SpvId SpvBuilder::declare_null_constant(const Type &type) { + ConstantKey key = make_null_constant_key(type); + ConstantMap::const_iterator it = constant_map.find(key); + if (it != constant_map.end()) { + return it->second; + } + + SpvId result_id = declare_id(SpvConstantId); + SpvId type_id = declare_type(type); + SpvInstruction inst = SpvFactory::null_constant(result_id, type_id); + module.add_constant(inst); + constant_map[key] = result_id; + return result_id; +} + +SpvId SpvBuilder::declare_bool_constant(bool value) { + const std::string key = make_bool_constant_key(value); + ConstantMap::const_iterator it = constant_map.find(key); + if (it != constant_map.end()) { + return it->second; + } + + debug(3) << "declare_bool_constant for " << value << "\n"; + + Type type = Bool(); + SpvId result_id = declare_id(SpvBoolConstantId); + SpvId type_id = declare_type(type); + SpvInstruction inst = SpvFactory::bool_constant(result_id, type_id, value); + module.add_constant(inst); + constant_map[key] = result_id; + return result_id; +} + +SpvId SpvBuilder::declare_scalar_constant(const Type &scalar_type, const void *data) { + if (scalar_type.lanes() != 1) { + internal_error << "SPIRV: Invalid type provided for scalar constant!" << scalar_type << "\n"; + return SpvInvalidId; + } + + const std::string constant_key = make_constant_key(scalar_type, data); + ConstantMap::const_iterator it = constant_map.find(constant_key); + if (it != constant_map.end()) { + return it->second; + } + + if (scalar_type.is_bool() && data) { + bool value = *reinterpret_cast(data); + return declare_bool_constant(value); + } + + debug(3) << "declare_scalar_constant for type " << scalar_type << "\n"; + + SpvId result_id = SpvInvalidId; + if (scalar_type.is_float()) { + result_id = declare_id(SpvFloatConstantId); + } else if (scalar_type.is_bool()) { + result_id = declare_id(SpvBoolConstantId); + } else if (scalar_type.is_int_or_uint()) { + result_id = declare_id(SpvIntConstantId); + } else { + internal_error << "SPIRV: Unsupported type:" << scalar_type << "\n"; + return SpvInvalidId; + } + + SpvId type_id = declare_type(scalar_type); + SpvInstruction inst = SpvFactory::constant(result_id, type_id, scalar_type.bytes(), data); + module.add_constant(inst); + constant_map[constant_key] = result_id; + return result_id; +} + +SpvId SpvBuilder::declare_vector_constant(const Type &type, const void *data) { + if (type.lanes() == 1) { + internal_error << "SPIRV: Invalid type provided for vector constant!" << type << "\n"; + return SpvInvalidId; + } + + const std::string key = make_constant_key(type, data); + ConstantMap::const_iterator it = constant_map.find(key); + if (it != constant_map.end()) { + return it->second; + } + + Type scalar_type = type.with_lanes(1); + std::vector components(type.lanes()); + if (scalar_type.is_float()) { + if (type.bits() == 64) { + const double *values = (const double *)data; + for (int c = 0; c < type.lanes(); c++) { + const double *entry = &(values[c]); + SpvId scalar_id = declare_scalar_constant(scalar_type, (const void *)entry); + components.push_back(scalar_id); + } + } else { + const float *values = (const float *)data; + for (int c = 0; c < type.lanes(); c++) { + const float *entry = &(values[c]); + SpvId scalar_id = declare_scalar_constant(scalar_type, (const void *)entry); + components.push_back(scalar_id); + } + } + } else if (scalar_type.is_bool()) { + const bool *values = (const bool *)data; + for (int c = 0; c < type.lanes(); c++) { + const bool *entry = &(values[c]); + SpvId scalar_id = declare_scalar_constant(scalar_type, (const void *)entry); + components.push_back(scalar_id); + } + } else if (scalar_type.is_int_or_uint()) { + if (type.bits() == 64) { + const uint64_t *values = (const uint64_t *)data; + for (int c = 0; c < type.lanes(); c++) { + const uint64_t *entry = &(values[c]); + SpvId scalar_id = declare_scalar_constant(scalar_type, (const void *)entry); + components.push_back(scalar_id); + } + } else { + const uint32_t *values = (const uint32_t *)data; + for (int c = 0; c < type.lanes(); c++) { + const uint32_t *entry = &(values[c]); + SpvId scalar_id = declare_scalar_constant(scalar_type, (const void *)entry); + components.push_back(scalar_id); + } + } + } else { + internal_error << "SPIRV: Unsupported type:" << type << "\n"; + return SpvInvalidId; + } + + SpvId result_id = declare_id(SpvCompositeConstantId); + SpvId type_id = declare_type(type); + SpvInstruction inst = SpvFactory::composite_constant(result_id, type_id, components); + module.add_constant(inst); + constant_map[key] = result_id; + return result_id; +} + +SpvId SpvBuilder::lookup_constant(const Type &type, const void *data) const { + ConstantKey key = make_constant_key(type, data); + ConstantMap::const_iterator it = constant_map.find(key); + if (it != constant_map.end()) { + return it->second; + } + return SpvInvalidId; +} + +SpvId SpvBuilder::declare_constant(const Type &type, const void *data) { + + const std::string key = make_constant_key(type, data); + ConstantMap::const_iterator it = constant_map.find(key); + if (it != constant_map.end()) { + return it->second; + } + + debug(3) << "declare_constant for type " << type << "\n"; + if (type.lanes() == 1) { + return declare_scalar_constant(type, data); + } else { + return declare_vector_constant(type, data); + } +} + +SpvId SpvBuilder::declare_access_chain(SpvId ptr_type_id, SpvId base_id, SpvId element_id, const Indices &indices) { + SpvId access_chain_id = declare_id(SpvAccessChainId); + append(SpvFactory::in_bounds_access_chain(ptr_type_id, access_chain_id, base_id, element_id, indices)); + return access_chain_id; +} + +SpvId SpvBuilder::map_instruction(const SpvInstruction &inst) { + const SpvId key = inst.result_id(); + if (instruction_map.find(key) == instruction_map.end()) { + instruction_map.insert({key, inst}); + } else { + instruction_map[key] = inst; + } + return key; +} + +SpvInstruction SpvBuilder::lookup_instruction(SpvId result_id) const { + InstructionMap::const_iterator it = instruction_map.find(result_id); + if (it == instruction_map.end()) { + return SpvInstruction(); + } + return it->second; +} + +SpvBuilder::FunctionTypeKey SpvBuilder::make_function_type_key(SpvId return_type_id, const ParamTypes ¶m_type_ids) const { + TypeKey key((1 + param_type_ids.size()) * sizeof(SpvId), ' '); + + uint32_t index = 0; + for (size_t i = 0; i < sizeof(uint32_t); i++, index++) { + key[index] = (return_type_id & 0xff); + return_type_id >>= 8; + } + for (SpvId type_id : param_type_ids) { + for (size_t i = 0; i < sizeof(uint32_t); i++, index++) { + key[index] = (type_id & 0xff); + type_id >>= 8; + } + } + return key; +} + +SpvId SpvBuilder::lookup_function_type(SpvId return_type_id, const ParamTypes ¶m_type_ids) const { + FunctionTypeKey key = make_function_type_key(return_type_id, param_type_ids); + FunctionTypeMap::const_iterator it = function_type_map.find(key); + if (it != function_type_map.end()) { + return it->second; + } + return SpvInvalidId; +} + +SpvId SpvBuilder::declare_function_type(SpvId return_type_id, const ParamTypes ¶m_type_ids) { + FunctionTypeKey func_type_key = make_function_type_key(return_type_id, param_type_ids); + FunctionTypeMap::const_iterator it = function_type_map.find(func_type_key); + if (it != function_type_map.end()) { + return it->second; + } + + SpvId function_type_id = declare_id(SpvFunctionTypeId); + SpvInstruction inst = SpvFactory::function_type(function_type_id, return_type_id, param_type_ids); + module.add_type(inst); + function_type_map[func_type_key] = function_type_id; + return function_type_id; +} + +SpvId SpvBuilder::declare_runtime_array(SpvId base_type_id) { + SpvId runtime_array_id = declare_id(SpvRuntimeArrayTypeId); + SpvInstruction inst = SpvFactory::runtime_array_type(runtime_array_id, base_type_id); + module.add_type(inst); + return runtime_array_id; +} + +void SpvBuilder::append(SpvInstruction inst) { + if (!block_stack.empty()) { + current_block().add_instruction(std::move(inst)); + } else { + internal_error << "SPIRV: Current block undefined! Unable to append!\n"; + } +} + +// -- + +// -- Factory Methods for Specific Instructions + +SpvInstruction SpvFactory::label(SpvId result_id) { + SpvInstruction inst = SpvInstruction::make(SpvOpLabel); + inst.set_result_id(result_id); + return inst; +} + +SpvInstruction SpvFactory::decorate(SpvId target_id, SpvDecoration decoration_type, const SpvFactory::Literals &literals) { + SpvInstruction inst = SpvInstruction::make(SpvOpDecorate); + inst.add_operand(target_id); + inst.add_immediate(decoration_type); + for (uint32_t l : literals) { + inst.add_immediate(l); + } + return inst; +} + +SpvInstruction SpvFactory::decorate_member(SpvId struct_type_id, uint32_t member_index, SpvDecoration decoration_type, const SpvFactory::Literals &literals) { + SpvInstruction inst = SpvInstruction::make(SpvOpMemberDecorate); + inst.add_operand(struct_type_id); + inst.add_immediate(decoration_type); + for (uint32_t l : literals) { + inst.add_immediate(l); + } + return inst; +} + +SpvInstruction SpvFactory::unary_op(SpvOp op_code, SpvId type_id, SpvId result_id, SpvId src_id) { + SpvInstruction inst = SpvInstruction::make(op_code); + inst.set_type_id(type_id); + inst.set_result_id(result_id); + inst.add_operand(src_id); + return inst; +} + +SpvInstruction SpvFactory::binary_op(SpvOp op_code, SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id) { + SpvInstruction inst = SpvInstruction::make(op_code); + inst.set_type_id(type_id); + inst.set_result_id(result_id); + inst.add_operand(src_a_id); + inst.add_operand(src_b_id); + return inst; +} + +SpvInstruction SpvFactory::convert(SpvOp op_code, SpvId type_id, SpvId result_id, SpvId src_id) { + SpvInstruction inst = SpvInstruction::make(op_code); + inst.set_type_id(type_id); + inst.set_result_id(result_id); + inst.add_operand(src_id); + return inst; +} + +SpvInstruction SpvFactory::void_type(SpvId void_type_id) { + SpvInstruction inst = SpvInstruction::make(SpvOpTypeVoid); + inst.set_result_id(void_type_id); + return inst; +} + +SpvInstruction SpvFactory::bool_type(SpvId bool_type_id) { + SpvInstruction inst = SpvInstruction::make(SpvOpTypeBool); + inst.set_result_id(bool_type_id); + return inst; +} + +SpvInstruction SpvFactory::integer_type(SpvId int_type_id, uint32_t bits, uint32_t signedness) { + SpvInstruction inst = SpvInstruction::make(SpvOpTypeInt); + inst.set_result_id(int_type_id); + inst.add_immediate(bits); + inst.add_immediate(signedness); + return inst; +} + +SpvInstruction SpvFactory::float_type(SpvId float_type_id, uint32_t bits) { + SpvInstruction inst = SpvInstruction::make(SpvOpTypeFloat); + inst.set_result_id(float_type_id); + inst.add_immediate(bits); + return inst; +} + +SpvInstruction SpvFactory::vector_type(SpvId vector_type_id, SpvId element_type_id, uint32_t vector_size) { + SpvInstruction inst = SpvInstruction::make(SpvOpTypeVector); + inst.set_result_id(vector_type_id); + inst.add_operand(element_type_id); + inst.add_immediate(vector_size); + return inst; +} + +SpvInstruction SpvFactory::array_type(SpvId array_type_id, SpvId element_type_id, uint32_t array_size) { + SpvInstruction inst = SpvInstruction::make(SpvOpTypeArray); + inst.set_result_id(array_type_id); + inst.add_operand(element_type_id); + inst.add_immediate(array_size); + return inst; +} + +SpvInstruction SpvFactory::struct_type(SpvId result_id, const SpvFactory::MemberTypeIds &member_type_ids) { + SpvInstruction inst = SpvInstruction::make(SpvOpTypeStruct); + inst.set_result_id(result_id); + for (const SpvId member_type : member_type_ids) { + inst.add_operand(member_type); + } + return inst; +} + +SpvInstruction SpvFactory::runtime_array_type(SpvId result_type_id, SpvId base_type_id) { + SpvInstruction inst = SpvInstruction::make(SpvOpTypeRuntimeArray); + inst.set_result_id(result_type_id); + inst.add_operand(base_type_id); + return inst; +} + +SpvInstruction SpvFactory::pointer_type(SpvId pointer_type_id, SpvStorageClass storage_class, SpvId base_type_id) { + SpvInstruction inst = SpvInstruction::make(SpvOpTypePointer); + inst.set_result_id(pointer_type_id); + inst.add_immediate(storage_class); + inst.add_operand(base_type_id); + return inst; +} + +SpvInstruction SpvFactory::function_type(SpvId function_type_id, SpvId return_type_id, const SpvFactory::ParamTypes ¶m_type_ids) { + SpvInstruction inst = SpvInstruction::make(SpvOpTypeFunction); + inst.set_type_id(return_type_id); + inst.set_result_id(function_type_id); + for (SpvId type_id : param_type_ids) { + inst.add_operand(type_id); + } + return inst; +} + +SpvInstruction SpvFactory::constant(SpvId result_id, SpvId type_id, size_t bytes, const void *data) { + SpvInstruction inst = SpvInstruction::make(SpvOpConstant); + inst.set_type_id(type_id); + inst.set_result_id(result_id); + inst.add_data(bytes, data); + return inst; +} + +SpvInstruction SpvFactory::null_constant(SpvId result_id, SpvId type_id) { + SpvInstruction inst = SpvInstruction::make(SpvOpConstantNull); + inst.set_type_id(type_id); + inst.set_result_id(result_id); + return inst; +} + +SpvInstruction SpvFactory::bool_constant(SpvId result_id, SpvId type_id, bool value) { + SpvOp op_code = value ? SpvOpConstantTrue : SpvOpConstantFalse; + SpvInstruction inst = SpvInstruction::make(op_code); + inst.set_type_id(type_id); + inst.set_result_id(result_id); + return inst; +} + +SpvInstruction SpvFactory::composite_constant(SpvId result_id, SpvId type_id, const SpvFactory::Components &components) { + SpvInstruction inst = SpvInstruction::make(SpvOpConstantComposite); + inst.set_type_id(type_id); + inst.set_result_id(result_id); + for (SpvId scalar_id : components) { + inst.add_operand(scalar_id); + } + return inst; +} + +SpvInstruction SpvFactory::variable(SpvId result_id, SpvId result_type_id, uint32_t storage_class, SpvId initializer_id) { + SpvInstruction inst = SpvInstruction::make(SpvOpVariable); + inst.set_type_id(result_type_id); + inst.set_result_id(result_id); + inst.add_immediate(storage_class); + if (initializer_id != SpvInvalidId) { + inst.add_operand(initializer_id); + } + return inst; +} + +SpvInstruction SpvFactory::function(SpvId return_type_id, SpvId func_id, uint32_t control_mask, SpvId func_type_id) { + SpvInstruction inst = SpvInstruction::make(SpvOpFunction); + inst.set_type_id(return_type_id); + inst.set_result_id(func_id); + inst.add_immediate(control_mask); + inst.add_operand(func_type_id); + return inst; +} + +SpvInstruction SpvFactory::function_parameter(SpvId param_type_id, SpvId param_id) { + SpvInstruction inst = SpvInstruction::make(SpvOpFunctionParameter); + inst.set_type_id(param_type_id); + inst.set_result_id(param_id); + return inst; +} + +SpvInstruction SpvFactory::function_end() { + SpvInstruction inst = SpvInstruction::make(SpvOpFunctionEnd); + return inst; +} + +SpvInstruction SpvFactory::return_stmt(SpvId return_value_id) { + SpvOp opcode = (return_value_id == SpvInvalidId) ? SpvOpReturn : SpvOpReturnValue; + SpvInstruction inst = SpvInstruction::make(opcode); + if (return_value_id != SpvInvalidId) { + inst.add_operand(return_value_id); + } + return inst; +} + +SpvInstruction SpvFactory::entry_point(SpvId exec_model, SpvId func_id, const std::string &name, const SpvFactory::Variables &variables) { + SpvInstruction inst = SpvInstruction::make(SpvOpEntryPoint); + inst.add_immediate(exec_model); + inst.add_operand(func_id); + inst.add_string(name); + for (SpvId var : variables) { + inst.add_operand(var); + } + return inst; +} + +SpvInstruction SpvFactory::memory_model(SpvAddressingModel addressing_model, SpvMemoryModel memory_model) { + SpvInstruction inst = SpvInstruction::make(SpvOpMemoryModel); + inst.add_immediate(addressing_model); + inst.add_immediate(memory_model); + return inst; +} + +SpvInstruction SpvFactory::exec_mode_local_size(SpvId function_id, uint32_t wg_size_x, uint32_t wg_size_y, uint32_t wg_size_z) { + SpvInstruction inst = SpvInstruction::make(SpvOpExecutionMode); + inst.add_operand(function_id); + inst.add_immediate(SpvExecutionModeLocalSize); + inst.add_immediate(wg_size_x); + inst.add_immediate(wg_size_y); + inst.add_immediate(wg_size_z); + return inst; +} + +SpvInstruction SpvFactory::control_barrier(SpvId execution_scope_id, SpvId memory_scope_id, uint32_t semantics_mask) { + SpvInstruction inst = SpvInstruction::make(SpvOpControlBarrier); + inst.add_operand(execution_scope_id); + inst.add_operand(memory_scope_id); + inst.add_immediate(semantics_mask); + return inst; +} + +SpvInstruction SpvFactory::logical_not(SpvId type_id, SpvId result_id, SpvId src_id) { + return unary_op(SpvOpNot, type_id, result_id, src_id); +} + +SpvInstruction SpvFactory::multiply_extended(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id, bool is_signed) { + return binary_op(is_signed ? SpvOpSMulExtended : SpvOpUMulExtended, type_id, result_id, src_a_id, src_b_id); +} + +SpvInstruction SpvFactory::select(SpvId type_id, SpvId result_id, SpvId condition_id, SpvId true_id, SpvId false_id) { + SpvInstruction inst = SpvInstruction::make(SpvOpSelect); + inst.set_type_id(type_id); + inst.set_result_id(result_id); + inst.add_operand(condition_id); + inst.add_operand(true_id); + inst.add_operand(false_id); + return inst; +} + +SpvInstruction SpvFactory::in_bounds_access_chain(SpvId type_id, SpvId result_id, SpvId base_id, SpvId element_id, const SpvFactory::Indices &indices) { + SpvInstruction inst = SpvInstruction::make(SpvOpInBoundsAccessChain); + inst.set_type_id(type_id); + inst.set_result_id(result_id); + inst.add_operand(base_id); + inst.add_operand(element_id); + for (SpvId i : indices) { + inst.add_operand(i); + } + return inst; +} + +SpvInstruction SpvFactory::load(SpvId type_id, SpvId result_id, SpvId ptr_id, uint32_t access_mask) { + SpvInstruction inst = SpvInstruction::make(SpvOpLoad); + inst.set_type_id(type_id); + inst.set_result_id(result_id); + inst.add_operand(ptr_id); + inst.add_immediate(access_mask); + return inst; +} + +SpvInstruction SpvFactory::store(SpvId ptr_id, SpvId obj_id, uint32_t access_mask) { + SpvInstruction inst = SpvInstruction::make(SpvOpStore); + inst.add_operand(ptr_id); + inst.add_operand(obj_id); + inst.add_immediate(access_mask); + return inst; +} + +SpvInstruction SpvFactory::composite_extract(SpvId type_id, SpvId result_id, SpvId composite_id, const SpvFactory::Indices &indices) { + SpvInstruction inst = SpvInstruction::make(SpvOpCompositeExtract); + inst.set_type_id(type_id); + inst.set_result_id(result_id); + inst.add_operand(composite_id); + for (SpvId i : indices) { + inst.add_immediate(i); + } + return inst; +} + +SpvInstruction SpvFactory::vector_insert_dynamic(SpvId result_id, SpvId vector_id, SpvId value_id, uint32_t index) { + SpvInstruction inst = SpvInstruction::make(SpvOpVectorInsertDynamic); + inst.set_type_id(SpvOpTypeVector); + inst.set_result_id(result_id); + inst.add_operand(vector_id); + inst.add_operand(value_id); + inst.add_immediate(index); + return inst; +} + +SpvInstruction SpvFactory::bitcast(SpvId type_id, SpvId result_id, SpvId src_id) { + SpvInstruction inst = SpvInstruction::make(SpvOpBitcast); + inst.set_type_id(type_id); + inst.set_result_id(result_id); + inst.add_operand(src_id); + return inst; +} + +SpvInstruction SpvFactory::integer_add(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id) { + return binary_op(SpvOpIAdd, type_id, result_id, src_a_id, src_b_id); +} + +SpvInstruction SpvFactory::branch(SpvId target_label_id) { + SpvInstruction inst = SpvInstruction::make(SpvOpBranch); + inst.add_operand(target_label_id); + return inst; +} + +SpvInstruction SpvFactory::conditional_branch(SpvId condition_label_id, SpvId true_label_id, SpvId false_label_id, const SpvFactory::BranchWeights &weights) { + SpvInstruction inst = SpvInstruction::make(SpvOpBranch); + inst.add_operand(condition_label_id); + inst.add_operand(true_label_id); + inst.add_operand(false_label_id); + for (uint32_t w : weights) { + inst.add_immediate(w); + } + return inst; +} + +SpvInstruction SpvFactory::loop_merge(SpvId merge_label_id, SpvId continue_label_id, uint32_t loop_control_mask) { + SpvInstruction inst = SpvInstruction::make(SpvOpLoopMerge); + inst.add_operand(merge_label_id); + inst.add_operand(continue_label_id); + inst.add_immediate(loop_control_mask); + return inst; +} + +SpvInstruction SpvFactory::selection_merge(SpvId merge_label_id, uint32_t selection_control_mask) { + SpvInstruction inst = SpvInstruction::make(SpvOpSelectionMerge); + inst.add_operand(merge_label_id); + inst.add_immediate(selection_control_mask); + return inst; +} + +SpvInstruction SpvFactory::phi(SpvId type_id, SpvId result_id, const SpvFactory::BlockVariables &block_vars) { + SpvInstruction inst = SpvInstruction::make(SpvOpPhi); + inst.set_type_id(type_id); + inst.set_result_id(result_id); + for (const SpvFactory::VariableBlockIdPair &vb : block_vars) { + inst.add_operand(vb.first); // variable id + inst.add_operand(vb.second); // block id + } + return inst; +} + +SpvInstruction SpvFactory::capability(const SpvCapability &capability) { + SpvInstruction inst = SpvInstruction::make(SpvOpCapability); + inst.add_immediate(capability); + return inst; +} + +SpvInstruction SpvFactory::extension(const std::string &extension) { + SpvInstruction inst = SpvInstruction::make(SpvOpExtension); + inst.add_string(extension); + return inst; +} + +SpvInstruction SpvFactory::import(const std::string &import) { + SpvInstruction inst = SpvInstruction::make(SpvOpExtInstImport); + inst.add_string(import); + return inst; +} + +/** Specializations for reference counted classes */ +template<> +RefCount &ref_count(const SpvInstructionContents *c) noexcept { + return c->ref_count; +} + +template<> +void destroy(const SpvInstructionContents *c) { + delete c; +} + +template<> +RefCount &ref_count(const SpvBlockContents *c) noexcept { + return c->ref_count; +} + +template<> +void destroy(const SpvBlockContents *c) { + delete c; +} + +template<> +RefCount &ref_count(const SpvFunctionContents *c) noexcept { + return c->ref_count; +} + +template<> +void destroy(const SpvFunctionContents *c) { + delete c; +} + +template<> +RefCount &ref_count(const SpvModuleContents *c) noexcept { + return c->ref_count; +} + +template<> +void destroy(const SpvModuleContents *c) { + delete c; +} + +} // namespace Internal +} // namespace Halide + +#endif // WITH_SPIRV + +namespace Halide { +namespace Internal { + +void spirv_ir_test() { + +#ifdef WITH_SPIRV + SpvBinary binary; + SpvInstruction label_inst = SpvFactory::label(777); + assert(label_inst.result_id() == 777); + assert(label_inst.op_code() == SpvOpLabel); + label_inst.encode(binary); + assert(binary.size() == 2); // encodes to 2x 32-bit words [Length|OpCode, ResultId] + + SpvBuilder builder; + SpvId void_type_id = builder.reserve_id(SpvVoidTypeId); + SpvInstruction void_inst = SpvFactory::void_type(void_type_id); + builder.current_module().add_type(void_inst); + + SpvId int_type_id = builder.map_type(Int(32)); + SpvId uint_type_id = builder.map_type(UInt(32)); + SpvId float_type_id = builder.map_type(Float(32)); + + SpvBuilder::ParamTypes param_types = {int_type_id, uint_type_id, float_type_id}; + SpvFunction function = builder.add_function(void_type_id, param_types); + + builder.enter_function(function); + SpvId intrinsic_type_id = builder.map_type(Type(Type::UInt, 32, 3)); + SpvId intrinsic_id = builder.add_global_variable(intrinsic_type_id, SpvStorageClassInput); + + SpvId output_type_id = builder.map_type(Type(Type::UInt, 32, 1)); + SpvId output_id = builder.add_global_variable(output_type_id, SpvStorageClassOutput); + + SpvBuilder::Variables entry_point_variables; + entry_point_variables.push_back(intrinsic_id); + entry_point_variables.push_back(output_id); + builder.add_entry_point("entry_func", function.id(), SpvExecutionModelKernel, entry_point_variables); + + SpvBuilder::Literals annotation_literals = {SpvBuiltInWorkgroupId}; + builder.add_annotation(intrinsic_id, SpvDecorationBuiltIn, annotation_literals); + + SpvId intrinsic_loaded_id = builder.reserve_id(); + builder.append(SpvFactory::load(intrinsic_type_id, intrinsic_loaded_id, intrinsic_id)); + + float float_value = 32.0f; + SpvId float_src_id = builder.declare_constant(Float(32), &float_value); + SpvId converted_value_id = builder.reserve_id(SpvResultId); + builder.append(SpvFactory::convert(SpvOpConvertFToU, uint_type_id, converted_value_id, float_src_id)); + builder.append(SpvFactory::store(output_id, converted_value_id)); + builder.leave_function(); + + binary.clear(); + builder.encode(binary); + + std::cout << "SpirV IR test passed" << std::endl; +#else + std::cout << "SpirV IR test *disabled*" << std::endl; +#endif +} + +} // namespace Internal +} // namespace Halide diff --git a/src/SpirvIR.h b/src/SpirvIR.h new file mode 100644 index 000000000000..0c3356162820 --- /dev/null +++ b/src/SpirvIR.h @@ -0,0 +1,570 @@ +#ifndef HALIDE_SPIRV_IR_H +#define HALIDE_SPIRV_IR_H + +/** \file + * Defines methods for constructing and encoding instructions into the Khronos + * format specification known as the Standard Portable Intermediate Representation + * for Vulkan (SPIR-V). These class interfaces adopt Halide's conventions for its + * own IR, but is implemented as a stand-alone optional component that can be + * enabled as required for certain runtimes (eg Vulkan). + * + * NOTE: This file is only used internally for CodeGen! *DO NOT* add this file + * to the list of exported Halide headers in the src/CMakeFiles.txt or the + * top level Makefile. + */ +#ifdef WITH_SPIRV + +#include +#include +#include +#include +#include + +#include "IntrusivePtr.h" +#include "Type.h" + +#include // Use v1.0 spec as the minimal viable version (for maximum compatiblity) + +namespace Halide { +namespace Internal { + +/** Precision requirment for return values */ +enum SpvPrecision { + SpvFullPrecision, + SpvRelaxedPrecision, +}; + +/** Specific types of predefined constants */ +enum SpvPredefinedConstant { + SpvNullConstant, + SpvTrueConstant, + SpvFalseConstant, +}; + +/** Specific types of SPIR-V object ids */ +enum SpvKind { + SpvInvalidItem, + SpvTypeId, + SpvVoidTypeId, + SpvBoolTypeId, + SpvIntTypeId, + SpvFloatTypeId, + SpvVectorTypeId, + SpvArrayTypeId, + SpvRuntimeArrayTypeId, + SpvStringTypeId, + SpvPointerTypeId, + SpvStructTypeId, + SpvFunctionTypeId, + SpvAccessChainId, + SpvConstantId, + SpvBoolConstantId, + SpvIntConstantId, + SpvFloatConstantId, + SpvStringConstantId, + SpvCompositeConstantId, + SpvResultId, + SpvVariableId, + SpvInstructionId, + SpvFunctionId, + SpvBlockId, + SpvLabelId, + SpvParameterId, + SpvModuleId, + SpvUnknownItem, +}; + +/** SPIR-V requires all IDs to be 32-bit unsigned integers */ +using SpvId = uint32_t; +using SpvBinary = std::vector; + +static constexpr SpvId SpvInvalidId = SpvId(-1); +static constexpr SpvId SpvNoResult = 0; +static constexpr SpvId SpvNoType = 0; + +/** Pre-declarations for SPIR-V IR classes */ +class SpvModule; +class SpvFunction; +class SpvBlock; +class SpvInstruction; +class SpvBuilder; + +/** Pre-declarations for SPIR-V IR data structures */ +struct SpvModuleContents; +struct SpvFunctionContents; +struct SpvBlockContents; +struct SpvInstructionContents; + +/** Intrusive pointer types for SPIR-V IR data */ +using SpvModuleContentsPtr = IntrusivePtr; +using SpvFunctionContentsPtr = IntrusivePtr; +using SpvBlockContentsPtr = IntrusivePtr; +using SpvInstructionContentsPtr = IntrusivePtr; + +/** General interface for representing a SPIR-V Instruction */ +class SpvInstruction { +public: + SpvInstruction() = default; + ~SpvInstruction() = default; + + SpvInstruction(const SpvInstruction &) = default; + SpvInstruction &operator=(const SpvInstruction &) = default; + SpvInstruction(SpvInstruction &&) = default; + SpvInstruction &operator=(SpvInstruction &&) = default; + + void set_block(SpvBlock block); + void set_result_id(SpvId id); + void set_type_id(SpvId id); + void set_op_code(SpvOp opcode); + void add_operand(SpvId id); + void add_immediate(SpvId id); + void add_data(uint32_t bytes, const void *data); + void add_string(const std::string &str); + + SpvId result_id() const; + SpvId type_id() const; + SpvOp op_code() const; + SpvId operand(uint32_t index); + + bool has_type() const; + bool has_result() const; + bool is_defined() const; + bool is_immediate(uint32_t index) const; + uint32_t length() const; + SpvBlock block() const; + void check_defined() const; + + void encode(SpvBinary &binary) const; + + static SpvInstruction make(SpvOp op_code); + +protected: + SpvInstructionContentsPtr contents; +}; + +/** General interface for representing a SPIR-V Block */ +class SpvBlock { +public: + using Instructions = std::vector; + using Variables = std::vector; + using Blocks = std::vector; + + SpvBlock() = default; + ~SpvBlock() = default; + + SpvBlock(const SpvBlock &) = default; + SpvBlock &operator=(const SpvBlock &) = default; + SpvBlock(SpvBlock &&) = default; + SpvBlock &operator=(SpvBlock &&) = default; + + void add_instruction(SpvInstruction inst); + void add_variable(SpvInstruction var); + void set_function(SpvFunction func); + const Instructions &instructions() const; + const Variables &variables() const; + SpvFunction function() const; + bool is_reachable() const; + bool is_terminated() const; + bool is_defined() const; + SpvId id() const; + void check_defined() const; + + void encode(SpvBinary &binary) const; + + static SpvBlock make(SpvFunction func, SpvId id); + +protected: + SpvBlockContentsPtr contents; +}; + +/** General interface for representing a SPIR-V Function */ +class SpvFunction { +public: + SpvFunction() = default; + ~SpvFunction() = default; + + SpvFunction(const SpvFunction &) = default; + SpvFunction &operator=(const SpvFunction &) = default; + SpvFunction(SpvFunction &&) = default; + SpvFunction &operator=(SpvFunction &&) = default; + + void add_block(const SpvBlock &block); + void add_parameter(const SpvInstruction ¶m); + void set_module(SpvModule module); + void set_return_precision(SpvPrecision precision); + void set_parameter_precision(uint32_t index, SpvPrecision precision); + bool is_defined() const; + + SpvBlock entry_block() const; + SpvPrecision return_precision() const; + SpvPrecision parameter_precision(uint32_t index) const; + uint32_t parameter_count() const; + uint32_t control_mask() const; + SpvInstruction declaration() const; + SpvModule module() const; + SpvId return_type_id() const; + SpvId type_id() const; + SpvId id() const; + void check_defined() const; + + void encode(SpvBinary &binary) const; + + static SpvFunction make(SpvId func_id, SpvId func_type_id, SpvId return_type_id, uint32_t control_mask = SpvFunctionControlMaskNone); + +protected: + SpvFunctionContentsPtr contents; +}; + +/** General interface for representing a SPIR-V code module */ +class SpvModule { +public: + using EntryPointNames = std::vector; + using Instructions = std::vector; + + SpvModule() = default; + ~SpvModule() = default; + + SpvModule(const SpvModule &) = default; + SpvModule &operator=(const SpvModule &) = default; + SpvModule(SpvModule &&) = default; + SpvModule &operator=(SpvModule &&) = default; + + void add_debug(const SpvInstruction &val); + void add_annotation(const SpvInstruction &val); + void add_type(const SpvInstruction &val); + void add_constant(const SpvInstruction &val); + void add_global(const SpvInstruction &val); + void add_execution_mode(const SpvInstruction &val); + void add_function(SpvFunction val); + void add_instruction(const SpvInstruction &val); + void add_entry_point(const std::string &name, SpvInstruction entry_point); + + void require_capability(SpvCapability val); + void require_extension(const std::string &val); + + void set_source_language(SpvSourceLanguage val); + void set_addressing_model(SpvAddressingModel val); + void set_memory_model(SpvMemoryModel val); + SpvSourceLanguage source_language() const; + SpvAddressingModel addressing_model() const; + SpvMemoryModel memory_model() const; + SpvInstruction entry_point(const std::string &name) const; + EntryPointNames entry_point_names() const; + const Instructions &execution_modes() const; + SpvModule module() const; + + bool is_capability_required(SpvCapability val) const; + bool is_extension_required(const std::string &val) const; + bool is_defined() const; + SpvId id() const; + void check_defined() const; + + void encode(SpvBinary &binary) const; + + static SpvModule make(SpvId module_id, + SpvSourceLanguage source_language = SpvSourceLanguageUnknown, + SpvAddressingModel addressing_model = SpvAddressingModelLogical, + SpvMemoryModel memory_model = SpvMemoryModelSimple); + +protected: + SpvModuleContentsPtr contents; +}; + +/** Builder interface for constructing a SPIR-V code module and + * all associated types, declarations, blocks, functions & + * instructions */ +class SpvBuilder { +public: + using ParamTypes = std::vector; + using StructMemberTypes = std::vector; + using Variables = std::vector; + using Indices = std::vector; + using Literals = std::vector; + + SpvBuilder(); + ~SpvBuilder() = default; + + SpvBuilder(const SpvBuilder &) = delete; + SpvBuilder &operator=(const SpvBuilder &) = delete; + + SpvId reserve_id(SpvKind = SpvResultId); + SpvKind kind_of(SpvId id); + + SpvId map_type(const Type &type, uint32_t array_size = 1); + SpvId map_pointer_type(const Type &type, SpvStorageClass storage_class); + SpvId map_pointer_type(SpvId type_id, SpvStorageClass storage_class); + SpvId map_constant(const Type &type, const void *data); + SpvId map_null_constant(const Type &type); + SpvId map_bool_constant(bool value); + SpvId map_function_type(SpvId return_type, const ParamTypes ¶m_types = {}); + + SpvId declare_type(const Type &type, uint32_t array_size = 1); + SpvId declare_struct(const StructMemberTypes &member_types); + SpvId declare_runtime_array(SpvId base_type_id); + SpvId declare_pointer_type(const Type &type, SpvStorageClass storage_class); + SpvId declare_pointer_type(SpvId base_type_id, SpvStorageClass storage_class); + SpvId declare_constant(const Type &type, const void *data); + SpvId declare_null_constant(const Type &type); + SpvId declare_bool_constant(bool value); + SpvId declare_string_constant(const std::string &str); + SpvId declare_scalar_constant(const Type &type, const void *data); + SpvId declare_vector_constant(const Type &type, const void *data); + SpvId declare_access_chain(SpvId ptr_type_id, SpvId base_id, SpvId element_id, const Indices &indices); + SpvId declare_function_type(SpvId return_type_id, const ParamTypes ¶m_type_ids); + + SpvFunction add_function(SpvId return_type, const ParamTypes ¶m_types = {}); + SpvId add_instruction(SpvInstruction val); + void add_annotation(SpvId target_id, SpvDecoration decoration_type, const Literals &literals = {}); + void add_struct_annotation(SpvId struct_type_id, uint32_t member_index, SpvDecoration decoration_type, const Literals &literals = {}); + + SpvId add_variable(SpvId type_id, uint32_t storage_class, SpvId initializer_id = SpvInvalidId); + SpvId add_global_variable(SpvId type_id, uint32_t storage_class, SpvId initializer_id = SpvInvalidId); + + SpvId map_struct(const StructMemberTypes &member_types); + + void add_entry_point(const std::string &name, + SpvId func_id, SpvExecutionModel exec_model, + const Variables &variables = {}); + + void add_execution_mode_local_size(SpvId entry_point_id, uint32_t wg_size_x, uint32_t wg_size_y, uint32_t wg_size_z); + + void set_source_language(SpvSourceLanguage val); + void set_addressing_model(SpvAddressingModel val); + void set_memory_model(SpvMemoryModel val); + + SpvSourceLanguage source_language() const; + SpvAddressingModel addressing_model() const; + SpvMemoryModel memory_model() const; + + void enter_block(const SpvBlock &block); + SpvBlock current_block() const; + SpvBlock leave_block(); + + void enter_function(const SpvFunction &func); + SpvFunction lookup_function(SpvId func_id) const; + SpvFunction current_function() const; + SpvFunction leave_function(); + + void set_current_id(SpvId id); + SpvId current_id() const; + + SpvModule current_module() const; + + void require_extension(const std::string &extension); + void require_capability(SpvCapability); + + bool is_extension_required(const std::string &extension) const; + bool is_capability_required(SpvCapability) const; + + void append(SpvInstruction inst); + void encode(SpvBinary &binary) const; + +protected: + using TypeKey = std::string; + using TypeMap = std::unordered_map; + using KindMap = std::unordered_map; + using PointerTypeKey = std::pair; + using PointerTypeMap = std::map; + using ConstantKey = std::string; + using ConstantMap = std::unordered_map; + using StringMap = std::unordered_map; + using InstructionMap = std::unordered_map; + using FunctionTypeKey = std::string; + using FunctionTypeMap = std::unordered_map; + using FunctionMap = std::unordered_map; + using FunctionStack = std::stack; + using BlockStack = std::stack; + + SpvId declare_id(SpvKind kind); + + TypeKey make_type_key(const Type &type, uint32_t array_size = 1) const; + SpvId lookup_type(const Type &type, uint32_t array_size = 1) const; + + TypeKey make_struct_type_key(const StructMemberTypes &member_types) const; + SpvId lookup_struct(const StructMemberTypes &member_types) const; + + PointerTypeKey make_pointer_type_key(const Type &type, SpvStorageClass storage_class) const; + SpvId lookup_pointer_type(const Type &type, SpvStorageClass storage_class) const; + + PointerTypeKey make_pointer_type_key(SpvId base_type_id, SpvStorageClass storage_class) const; + SpvId lookup_pointer_type(SpvId base_type_id, SpvStorageClass storage_class) const; + + ConstantKey make_bool_constant_key(bool value) const; + + ConstantKey make_constant_key(const Type &type, const void *data) const; + SpvId lookup_constant(const Type &type, const void *data) const; + + ConstantKey make_null_constant_key(const Type &type) const; + SpvId lookup_null_constant(const Type &type) const; + + SpvId map_instruction(const SpvInstruction &inst); + SpvInstruction lookup_instruction(SpvId result_id) const; + bool has_instruction(SpvId inst) const; + + FunctionTypeKey make_function_type_key(SpvId return_type_id, const ParamTypes ¶m_type_ids) const; + SpvId lookup_function_type(SpvId return_type_id, const ParamTypes ¶m_type_ids) const; + + SpvId scope_id = SpvInvalidId; + SpvModule module; + KindMap kind_map; + TypeMap type_map; + TypeMap struct_map; + StringMap string_map; + ConstantMap constant_map; + FunctionMap function_map; + InstructionMap instruction_map; + PointerTypeMap pointer_type_map; + FunctionTypeMap function_type_map; + FunctionStack function_stack; + BlockStack block_stack; +}; + +/** Factory interface for constructing specific SPIR-V instructions */ +struct SpvFactory { + using Indices = std::vector; + using Literals = std::vector; + using BranchWeights = std::vector; + using Components = std::vector; + using ParamTypes = std::vector; + using MemberTypeIds = std::vector; + using Variables = std::vector; + using VariableBlockIdPair = std::pair; // (Variable Id, Block Id) + using BlockVariables = std::vector; + + static SpvInstruction capability(const SpvCapability &capability); + static SpvInstruction extension(const std::string &extension); + static SpvInstruction import(const std::string &import); + static SpvInstruction label(SpvId result_id); + static SpvInstruction decorate(SpvId target_id, SpvDecoration decoration_type, const Literals &literals = {}); + static SpvInstruction decorate_member(SpvId struct_type_id, uint32_t member_index, SpvDecoration decoration_type, const Literals &literals = {}); + static SpvInstruction void_type(SpvId void_type_id); + static SpvInstruction bool_type(SpvId bool_type_id); + static SpvInstruction integer_type(SpvId int_type_id, uint32_t bits, uint32_t signedness); + static SpvInstruction float_type(SpvId float_type_id, uint32_t bits); + static SpvInstruction vector_type(SpvId vector_type_id, SpvId element_type_id, uint32_t vector_size); + static SpvInstruction array_type(SpvId array_type_id, SpvId element_type_id, uint32_t array_size); + static SpvInstruction struct_type(SpvId result_id, const MemberTypeIds &member_type_ids); + static SpvInstruction runtime_array_type(SpvId result_type_id, SpvId base_type_id); + static SpvInstruction pointer_type(SpvId pointer_type_id, SpvStorageClass storage_class, SpvId base_type_id); + static SpvInstruction function_type(SpvId function_type_id, SpvId return_type_id, const ParamTypes ¶m_type_ids); + static SpvInstruction constant(SpvId result_id, SpvId type_id, size_t bytes, const void *data); + static SpvInstruction null_constant(SpvId result_id, SpvId type_id); + static SpvInstruction bool_constant(SpvId result_id, SpvId type_id, bool value); + static SpvInstruction composite_constant(SpvId result_id, SpvId type_id, const Components &components); + static SpvInstruction variable(SpvId result_id, SpvId result_type_id, uint32_t storage_class, SpvId initializer_id = SpvInvalidId); + static SpvInstruction function(SpvId return_type_id, SpvId func_id, uint32_t control_mask, SpvId func_type_id); + static SpvInstruction function_parameter(SpvId param_type_id, SpvId param_id); + static SpvInstruction function_end(); + static SpvInstruction return_stmt(SpvId return_value_id = SpvInvalidId); + static SpvInstruction entry_point(SpvId exec_model, SpvId func_id, const std::string &name, const Variables &variables); + static SpvInstruction memory_model(SpvAddressingModel addressing_model, SpvMemoryModel memory_model); + static SpvInstruction exec_mode_local_size(SpvId function_id, uint32_t wg_size_x, uint32_t wg_size_y, uint32_t wg_size_z); + static SpvInstruction control_barrier(SpvId execution_scope_id, SpvId memory_scope_id, uint32_t semantics_mask); + static SpvInstruction logical_not(SpvId type_id, SpvId result_id, SpvId src_id); + static SpvInstruction multiply_extended(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id, bool is_signed); + static SpvInstruction select(SpvId type_id, SpvId result_id, SpvId condition_id, SpvId true_id, SpvId false_id); + static SpvInstruction in_bounds_access_chain(SpvId type_id, SpvId result_id, SpvId base_id, SpvId element_id, const Indices &indices); + static SpvInstruction load(SpvId type_id, SpvId result_id, SpvId ptr_id, uint32_t access_mask = 0x0); + static SpvInstruction store(SpvId ptr_id, SpvId obj_id, uint32_t access_mask = 0x0); + static SpvInstruction vector_insert_dynamic(SpvId result_id, SpvId vector_id, SpvId value_id, uint32_t index); + static SpvInstruction composite_extract(SpvId type_id, SpvId result_id, SpvId composite_id, const Indices &indices); + static SpvInstruction bitcast(SpvId type_id, SpvId result_id, SpvId src_id); + static SpvInstruction integer_add(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id); + static SpvInstruction branch(SpvId target_label_id); + static SpvInstruction conditional_branch(SpvId condition_label_id, SpvId true_label_id, SpvId false_label_id, const BranchWeights &weights = {}); + static SpvInstruction loop_merge(SpvId merge_label_id, SpvId continue_label_id, uint32_t loop_control_mask = SpvLoopControlMaskNone); + static SpvInstruction selection_merge(SpvId merge_label_id, uint32_t selection_control_mask = SpvSelectionControlMaskNone); + static SpvInstruction phi(SpvId type_id, SpvId result_id, const BlockVariables &block_vars); + static SpvInstruction unary_op(SpvOp op_code, SpvId type_id, SpvId result_id, SpvId src_id); + static SpvInstruction binary_op(SpvOp op_code, SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id); + static SpvInstruction convert(SpvOp op_code, SpvId type_id, SpvId result_id, SpvId src_id); +}; + +/** Contents of a SPIR-V Instruction */ +struct SpvInstructionContents { + using Operands = std::vector; + using Immediates = std::vector; + mutable RefCount ref_count; + SpvOp op_code = SpvOpNop; + SpvId result_id = SpvNoResult; + SpvId type_id = SpvNoType; + Operands operands; + Immediates immediates; + SpvBlock block; +}; + +/** Contents of a SPIR-V code block */ +struct SpvBlockContents { + using Instructions = std::vector; + using Variables = std::vector; + using Blocks = std::vector; + mutable RefCount ref_count; + SpvId block_id = SpvInvalidId; + SpvFunction parent; + Instructions instructions; + Variables variables; + Blocks before; + Blocks after; + bool reachable = true; +}; + +/** Contents of a SPIR-V function */ +struct SpvFunctionContents { + using PrecisionMap = std::unordered_map; + using Parameters = std::vector; + using Blocks = std::vector; + mutable RefCount ref_count; + SpvModule parent; + SpvId function_id; + SpvId function_type_id; + SpvId return_type_id; + uint32_t control_mask; + SpvInstruction declaration; + Parameters parameters; + PrecisionMap precision; + Blocks blocks; +}; + +/** Contents of a SPIR-V code module */ +struct SpvModuleContents { + using Capabilities = std::set; + using Extensions = std::set; + using Imports = std::set; + using Functions = std::vector; + using Instructions = std::vector; + using EntryPoints = std::unordered_map; + + mutable RefCount ref_count; + SpvId module_id = SpvInvalidId; + SpvSourceLanguage source_language = SpvSourceLanguageUnknown; + SpvAddressingModel addressing_model = SpvAddressingModelLogical; + SpvMemoryModel memory_model = SpvMemoryModelSimple; + Capabilities capabilities; + Extensions extensions; + Imports imports; + EntryPoints entry_points; + Instructions execution_modes; + Instructions debug; + Instructions annotations; + Instructions types; + Instructions constants; + Instructions globals; + Functions functions; + Instructions instructions; +}; + +} // namespace Internal +} // namespace Halide + +#endif // WITH_SPIRV + +namespace Halide { +namespace Internal { + +/** Internal test for SPIR-V IR **/ +void spirv_ir_test(); + +} // namespace Internal +} // namespace Halide + +#endif // HALIDE_SPIRV_IR_H diff --git a/src/SplitTuples.cpp b/src/SplitTuples.cpp index 686f919a25af..b2de80d12e81 100644 --- a/src/SplitTuples.cpp +++ b/src/SplitTuples.cpp @@ -69,10 +69,11 @@ class SplitTuples : public IRMutator { if (op->types.size() > 1) { // Make a nested set of realize nodes for each tuple element Stmt body = mutate(op->body); + Expr condition = mutate(op->condition); for (int i = (int)op->types.size() - 1; i >= 0; i--) { body = Realize::make(op->name + "." + std::to_string(i), {op->types[i]}, op->memory_type, - op->bounds, op->condition, body); + op->bounds, condition, body); } return body; } else { diff --git a/src/StmtToHtml.cpp b/src/StmtToHtml.cpp index a77183e31d11..21bc74dd20ac 100644 --- a/src/StmtToHtml.cpp +++ b/src/StmtToHtml.cpp @@ -222,6 +222,19 @@ class StmtToHtml : public IRVisitor { stream << close_span(); } + void visit(const Reinterpret *op) override { + stream << open_span("Reinterpret"); + + stream << open_span("Matched"); + stream << open_span("Type") << op->type << close_span(); + stream << "("; + stream << close_span(); + print(op->value); + stream << matched(")"); + + stream << close_span(); + } + void visit_binary_op(const Expr &a, const Expr &b, const char *op) { stream << open_span("BinaryOp"); diff --git a/src/Target.cpp b/src/Target.cpp index cd9c08fe9448..7c6dbc64f5e0 100644 --- a/src/Target.cpp +++ b/src/Target.cpp @@ -62,6 +62,97 @@ static void cpuid(int info[4], int infoType, int extra) { #endif #endif +#if defined(__x86_64__) || defined(__i386__) || defined(_MSC_VER) + +enum class VendorSignatures { + Unknown, + GenuineIntel, + AuthenticAMD, +}; + +VendorSignatures get_vendor_signature() { + int info[4]; + cpuid(info, 0, 0); + + if (info[0] < 1) { + return VendorSignatures::Unknown; + } + + // "Genu ineI ntel" + if (info[1] == 0x756e6547 && info[3] == 0x49656e69 && info[2] == 0x6c65746e) { + return VendorSignatures::GenuineIntel; + } + + // "Auth enti cAMD" + if (info[1] == 0x68747541 && info[3] == 0x69746e65 && info[2] == 0x444d4163) { + return VendorSignatures::AuthenticAMD; + } + + return VendorSignatures::Unknown; +} + +void detect_family_and_model(int info0, unsigned &family, unsigned &model) { + family = (info0 >> 8) & 0xF; // Bits 8..11 + model = (info0 >> 4) & 0xF; // Bits 4..7 + if (family == 0x6 || family == 0xF) { + if (family == 0xF) { + // Examine extended family ID if family ID is 0xF. + family += (info0 >> 20) & 0xFf; // Bits 20..27 + } + // Examine extended model ID if family ID is 0x6 or 0xF. + model += ((info0 >> 16) & 0xF) << 4; // Bits 16..19 + } +} + +Target::Processor get_amd_processor(unsigned family, unsigned model, bool have_sse3) { + switch (family) { + case 0xF: // AMD Family 0Fh + if (have_sse3) { + return Target::Processor::K8_SSE3; // Hammer (modern, with SSE3) + } + return Target::Processor::K8; // Hammer (original, without SSE3) + case 0x10: // AMD Family 10h + return Target::Processor::AMDFam10; // Barcelona + case 0x14: // AMD Family 14h + return Target::Processor::BtVer1; // Bobcat + case 0x15: // AMD Family 15h + if (model >= 0x60 && model <= 0x7f) { + return Target::Processor::BdVer4; // 60h-7Fh: Excavator + } + if (model >= 0x30 && model <= 0x3f) { + return Target::Processor::BdVer3; // 30h-3Fh: Steamroller + } + if ((model >= 0x10 && model <= 0x1f) || model == 0x02) { + return Target::Processor::BdVer2; // 02h, 10h-1Fh: Piledriver + } + if (model <= 0x0f) { + return Target::Processor::BdVer1; // 00h-0Fh: Bulldozer + } + break; + case 0x16: // AMD Family 16h + return Target::Processor::BtVer2; // Jaguar + case 0x17: // AMD Family 17h + if ((model >= 0x30 && model <= 0x3f) || model == 0x71) { + return Target::Processor::ZnVer2; // 30h-3Fh, 71h: Zen2 + } + if (model <= 0x0f) { + return Target::Processor::ZnVer1; // 00h-0Fh: Zen1 + } + break; + case 0x19: // AMD Family 19h + if (model <= 0x0f || model == 0x21) { + return Target::Processor::ZnVer3; // 00h-0Fh, 21h: Zen3 + } + break; + default: + break; // Unknown AMD CPU. + } + + return Target::Processor::ProcessorGeneric; +} + +#endif // defined(__x86_64__) || defined(__i386__) || defined(_MSC_VER) + Target calculate_host_target() { Target::OS os = Target::OSUnknown; #ifdef __linux__ @@ -76,6 +167,8 @@ Target calculate_host_target() { bool use_64_bits = (sizeof(size_t) == 8); int bits = use_64_bits ? 64 : 32; + int vector_bits = 0; + Target::Processor processor = Target::Processor::ProcessorGeneric; std::vector initial_features; #if __riscv @@ -110,14 +203,21 @@ Target calculate_host_target() { #else Target::Arch arch = Target::X86; + VendorSignatures vendor_signature = get_vendor_signature(); + int info[4]; cpuid(info, 1, 0); - bool have_sse41 = (info[2] & (1 << 19)) != 0; - bool have_sse2 = (info[3] & (1 << 26)) != 0; - bool have_avx = (info[2] & (1 << 28)) != 0; - bool have_f16c = (info[2] & (1 << 29)) != 0; - bool have_rdrand = (info[2] & (1 << 30)) != 0; - bool have_fma = (info[2] & (1 << 12)) != 0; + + unsigned family = 0, model = 0; + detect_family_and_model(info[0], family, model); + + bool have_sse41 = (info[2] & (1 << 19)) != 0; // ECX[19] + bool have_sse2 = (info[3] & (1 << 26)) != 0; // EDX[26] + bool have_sse3 = (info[2] & (1 << 0)) != 0; // ECX[0] + bool have_avx = (info[2] & (1 << 28)) != 0; // ECX[28] + bool have_f16c = (info[2] & (1 << 29)) != 0; // ECX[29] + bool have_rdrand = (info[2] & (1 << 30)) != 0; // ECX[30] + bool have_fma = (info[2] & (1 << 12)) != 0; // ECX[12] user_assert(have_sse2) << "The x86 backend assumes at least sse2 support. This machine does not appear to have sse2.\n" @@ -128,6 +228,10 @@ Target calculate_host_target() { << ", " << info[3] << std::dec << "\n"; + if (vendor_signature == VendorSignatures::AuthenticAMD) { + processor = get_amd_processor(family, model, have_sse3); + } + if (have_sse41) { initial_features.push_back(Target::SSE41); } @@ -164,12 +268,15 @@ Target calculate_host_target() { } if ((info2[1] & avx512) == avx512) { initial_features.push_back(Target::AVX512); + // TODO: port to family/model -based detection. if ((info2[1] & avx512_knl) == avx512_knl) { initial_features.push_back(Target::AVX512_KNL); } + // TODO: port to family/model -based detection. if ((info2[1] & avx512_skylake) == avx512_skylake) { initial_features.push_back(Target::AVX512_Skylake); } + // TODO: port to family/model -based detection. if ((info2[1] & avx512_cannonlake) == avx512_cannonlake) { initial_features.push_back(Target::AVX512_Cannonlake); @@ -177,6 +284,7 @@ Target calculate_host_target() { const uint32_t avx512bf16 = 1U << 5; // bf16 result in eax, with cpuid(eax=7, ecx=1) int info3[4]; cpuid(info3, 7, 1); + // TODO: port to family/model -based detection. if ((info2[2] & avx512vnni) == avx512vnni && (info3[0] & avx512bf16) == avx512bf16) { initial_features.push_back(Target::AVX512_SapphireRapids); @@ -189,7 +297,7 @@ Target calculate_host_target() { #endif #endif - return {os, arch, bits, initial_features}; + return {os, arch, bits, processor, initial_features, vector_bits}; } bool is_using_hexagon(const Target &t) { @@ -254,7 +362,7 @@ Target::Feature calculate_host_cuda_capability(Target t) { return Target::CUDACapability70; } else if (ver < 80) { return Target::CUDACapability75; - } else if (ver < 86 || LLVM_VERSION < 130) { + } else if (ver < 86) { return Target::CUDACapability80; } else { return Target::CUDACapability86; @@ -307,6 +415,38 @@ bool lookup_arch(const std::string &tok, Target::Arch &result) { return false; } +/// Important design consideration: currently, the string key is +/// effectively identical to the LLVM CPU string, and it would be really really +/// good to keep it that way, so the proper tune_* can be autogenerated easily +/// from the LLVM CPU string (currently, by replacing "-" with "_", +/// and prepending "tune_" prefix) +/// +/// Please keep sorted. +const std::map processor_name_map = { + {"tune_amdfam10", Target::Processor::AMDFam10}, + {"tune_bdver1", Target::Processor::BdVer1}, + {"tune_bdver2", Target::Processor::BdVer2}, + {"tune_bdver3", Target::Processor::BdVer3}, + {"tune_bdver4", Target::Processor::BdVer4}, + {"tune_btver1", Target::Processor::BtVer1}, + {"tune_btver2", Target::Processor::BtVer2}, + {"tune_generic", Target::Processor::ProcessorGeneric}, + {"tune_k8", Target::Processor::K8}, + {"tune_k8_sse3", Target::Processor::K8_SSE3}, + {"tune_znver1", Target::Processor::ZnVer1}, + {"tune_znver2", Target::Processor::ZnVer2}, + {"tune_znver3", Target::Processor::ZnVer3}, +}; + +bool lookup_processor(const std::string &tok, Target::Processor &result) { + auto processor_iter = processor_name_map.find(tok); + if (processor_iter != processor_name_map.end()) { + result = processor_iter->second; + return true; + } + return false; +} + const std::map feature_name_map = { {"jit", Target::JIT}, {"debug", Target::Debug}, @@ -339,7 +479,6 @@ const std::map feature_name_map = { {"openglcompute", Target::OpenGLCompute}, {"egl", Target::EGL}, {"user_context", Target::UserContext}, - {"matlab", Target::Matlab}, {"profile", Target::Profile}, {"no_runtime", Target::NoRuntime}, {"metal", Target::Metal}, @@ -370,6 +509,9 @@ const std::map feature_name_map = { {"check_unsafe_promises", Target::CheckUnsafePromises}, {"hexagon_dma", Target::HexagonDma}, {"embed_bitcode", Target::EmbedBitcode}, + // halide_target_feature_disable_llvm_loop_opt is deprecated in Halide 15 + // (and will be removed in Halide 16). Halide 15 now defaults to disabling + // LLVM loop optimization, unless halide_target_feature_enable_llvm_loop_opt is set. {"disable_llvm_loop_opt", Target::DisableLLVMLoopOpt}, {"enable_llvm_loop_opt", Target::EnableLLVMLoopOpt}, {"wasm_simd128", Target::WasmSimd128}, @@ -386,6 +528,7 @@ const std::map feature_name_map = { {"armv81a", Target::ARMv81a}, {"sanitizer_coverage", Target::SanitizerCoverage}, {"profile_by_timer", Target::ProfileByTimer}, + {"spirv", Target::SPIRV}, // NOTE: When adding features to this map, be sure to update PyEnums.cpp as well. }; @@ -398,6 +541,18 @@ bool lookup_feature(const std::string &tok, Target::Feature &result) { return false; } +int parse_vector_bits(const std::string &tok) { + if (tok.find("vector_bits_") == 0) { + std::string num = tok.substr(sizeof("vector_bits_") - 1, std::string::npos); + size_t end_index; + int parsed = std::stoi(num, &end_index); + if (end_index == num.size()) { + return parsed; + } + } + return -1; +} + } // End anonymous namespace Target get_target_from_environment() { @@ -412,19 +567,18 @@ Target get_target_from_environment() { Target get_jit_target_from_environment() { Target host = get_host_target(); host.set_feature(Target::JIT); -#if defined(__has_feature) -#if __has_feature(address_sanitizer) +// Note, we must include Util.h for these to be defined properly (or not) +#ifdef HALIDE_INTERNAL_USING_ASAN host.set_feature(Target::ASAN); #endif -#if __has_feature(memory_sanitizer) +#ifdef HALIDE_INTERNAL_USING_MSAN host.set_feature(Target::MSAN); #endif -#if __has_feature(thread_sanitizer) +#ifdef HALIDE_INTERNAL_USING_TSAN host.set_feature(Target::TSAN); #endif -#if __has_feature(coverage_sanitizer) +#ifdef HALIDE_INTERNAL_USING_COVSAN host.set_feature(Target::SanitizerCoverage); -#endif #endif string target = Internal::get_env_variable("HL_JIT_TARGET"); if (target.empty()) { @@ -454,12 +608,13 @@ bool merge_string(Target &t, const std::string &target) { } tokens.push_back(rest); - bool os_specified = false, arch_specified = false, bits_specified = false, features_specified = false; + bool os_specified = false, arch_specified = false, bits_specified = false, processor_specified = false, features_specified = false; bool is_host = false; for (size_t i = 0; i < tokens.size(); i++) { const string &tok = tokens[i]; Target::Feature feature; + int vector_bits; if (tok == "host") { if (i > 0) { @@ -484,12 +639,19 @@ bool merge_string(Target &t, const std::string &target) { return false; } os_specified = true; + } else if (lookup_processor(tok, t.processor_tune)) { + if (processor_specified) { + return false; + } + processor_specified = true; } else if (lookup_feature(tok, feature)) { t.set_feature(feature); features_specified = true; } else if (tok == "trace_all") { t.set_features({Target::TraceLoads, Target::TraceStores, Target::TraceRealizations}); features_specified = true; + } else if ((vector_bits = parse_vector_bits(tok)) >= 0) { + t.vector_bits = vector_bits; } else { return false; } @@ -541,6 +703,12 @@ void bad_target_string(const std::string &target) { separator = ", "; } separator = ""; + std::string processors; + for (const auto &processor_entry : processor_name_map) { + processors += separator + processor_entry.first; + separator = ", "; + } + separator = ""; // Format the features to go one feature over 70 characters per line, // assume the first line starts with "Features are ". int line_char_start = -(int)sizeof("Features are"); @@ -555,13 +723,16 @@ void bad_target_string(const std::string &target) { } } user_error << "Did not understand Halide target " << target << "\n" - << "Expected format is arch-bits-os-feature1-feature2-...\n" + << "Expected format is arch-bits-os-processor-feature1-feature2-...\n" << "Where arch is: " << architectures << ".\n" << "bits is either 32 or 64.\n" << "os is: " << oses << ".\n" + << "processor is: " << processors << ".\n" << "\n" << "If arch, bits, or os are omitted, they default to the host.\n" << "\n" + << "If processor is omitted, it defaults to tune_generic.\n" + << "\n" << "Features are: " << features << ".\n" << "\n" << "The target can also begin with \"host\", which sets the " @@ -628,6 +799,14 @@ std::string Target::to_string() const { break; } } + if (processor_tune != ProcessorGeneric) { + for (const auto &processor_entry : processor_name_map) { + if (processor_entry.second == processor_tune) { + result += "-" + processor_entry.first; + break; + } + } + } for (const auto &feature_entry : feature_name_map) { if (has_feature(feature_entry.second)) { result += "-" + feature_entry.first; @@ -638,6 +817,10 @@ std::string Target::to_string() const { if (has_feature(Target::TraceLoads) && has_feature(Target::TraceStores) && has_feature(Target::TraceRealizations)) { result = Internal::replace_all(result, "trace_loads-trace_realizations-trace_stores", "trace_all"); } + if (vector_bits != 0) { + result += "-vector_bits_" + std::to_string(vector_bits); + } + return result; } @@ -898,7 +1081,15 @@ int Target::natural_vector_size(const Halide::Type &t) const { const bool is_integer = t.is_int() || t.is_uint(); const int data_size = t.bytes(); - if (arch == Target::Hexagon) { + if (arch == Target::ARM) { + if (vector_bits != 0 && + (has_feature(Halide::Target::SVE2) || + (t.is_float() && has_feature(Halide::Target::SVE)))) { + return vector_bits / (data_size * 8); + } else { + return 16 / data_size; + } + } else if (arch == Target::Hexagon) { if (is_integer) { if (has_feature(Halide::Target::HVX)) { return 128 / data_size; @@ -940,6 +1131,13 @@ int Target::natural_vector_size(const Halide::Type &t) const { // No vectors, sorry. return 1; } + } else if (arch == Target::RISCV) { + if (vector_bits != 0 && + has_feature(Halide::Target::RVV)) { + return vector_bits / (data_size * 8); + } else { + return 1; + } } else { // Assume 128-bit vectors on other targets. return 16 / data_size; @@ -1047,7 +1245,7 @@ bool Target::get_runtime_compatible_target(const Target &other, Target &result) // Union of features is computed through bitwise-or, and masked away by the features we care about // Intersection of features is computed through bitwise-and and masked away, too. // We merge the bits via bitwise or. - Target output = Target{os, arch, bits}; + Target output = Target{os, arch, bits, processor_tune}; output.features = ((features | other.features) & union_mask) | ((features | other.features) & matching_mask) | ((features & other.features) & intersection_mask); // Pick tight lower bound for CUDA capability. Use fall-through to clear redundant features @@ -1147,6 +1345,12 @@ void target_test() { } } + internal_assert(Target().vector_bits == 0) << "Default Target vector_bits not 0.\n"; + internal_assert(Target("arm-64-linux-sve2-vector_bits_512").vector_bits == 512) << "Vector bits not parsed correctly.\n"; + Target with_vector_bits(Target::Linux, Target::ARM, 64, Target::ProcessorGeneric, {Target::SVE}, 512); + internal_assert(with_vector_bits.vector_bits == 512) << "Vector bits not populated in constructor.\n"; + internal_assert(Target(with_vector_bits.to_string()).vector_bits == 512) << "Vector bits not round tripped properly.\n"; + std::cout << "Target test passed" << std::endl; } diff --git a/src/Target.h b/src/Target.h index 5b9588ab60d9..8678bfefbb90 100644 --- a/src/Target.h +++ b/src/Target.h @@ -50,6 +50,32 @@ struct Target { /** The bit-width of the target machine. Must be 0 for unknown, or 32 or 64. */ int bits = 0; + /** The bit-width of a vector register for targets where this is configurable and + * targeting a fixed size is desired. The default of 0 indicates no assumption of + * fixed size is allowed. */ + int vector_bits = 0; + + /** The specific processor to be targeted, tuned for. + * Corresponds to processor_name_map in Target.cpp. + * + * New entries should be added to the end. */ + enum Processor { + /// Do not tune for any specific CPU. In practice, this means that halide will decide the tune CPU based on the enabled features. + ProcessorGeneric = 0, + K8, /// Tune for AMD K8 Hammer CPU (AMD Family 0Fh, launched 2003). + K8_SSE3, /// Tune for later versions of AMD K8 CPU, with SSE3 support. + AMDFam10, /// Tune for AMD K10 "Barcelona" CPU (AMD Family 10h, launched 2007). + BtVer1, /// Tune for AMD Bobcat CPU (AMD Family 14h, launched 2011). + BdVer1, /// Tune for AMD Bulldozer CPU (AMD Family 15h, launched 2011). + BdVer2, /// Tune for AMD Piledriver CPU (AMD Family 15h (2nd-gen), launched 2012). + BdVer3, /// Tune for AMD Steamroller CPU (AMD Family 15h (3nd-gen), launched 2014). + BdVer4, /// Tune for AMD Excavator CPU (AMD Family 15h (4th-gen), launched 2015). + BtVer2, /// Tune for AMD Jaguar CPU (AMD Family 16h, launched 2013). + ZnVer1, /// Tune for AMD Zen CPU (AMD Family 17h, launched 2017). + ZnVer2, /// Tune for AMD Zen 2 CPU (AMD Family 17h, launched 2019). + ZnVer3, /// Tune for AMD Zen 3 CPU (AMD Family 19h, launched 2020). + } processor_tune = ProcessorGeneric; + /** Optional features a target can have. * Corresponds to feature_name_map in Target.cpp. * See definitions in HalideRuntime.h for full information. @@ -86,7 +112,6 @@ struct Target { OpenGLCompute = halide_target_feature_openglcompute, EGL = halide_target_feature_egl, UserContext = halide_target_feature_user_context, - Matlab = halide_target_feature_matlab, Profile = halide_target_feature_profile, NoRuntime = halide_target_feature_no_runtime, Metal = halide_target_feature_metal, @@ -118,6 +143,9 @@ struct Target { CheckUnsafePromises = halide_target_feature_check_unsafe_promises, EmbedBitcode = halide_target_feature_embed_bitcode, EnableLLVMLoopOpt = halide_target_feature_enable_llvm_loop_opt, + // halide_target_feature_disable_llvm_loop_opt is deprecated in Halide 15 + // (and will be removed in Halide 16). Halide 15 now defaults to disabling + // LLVM loop optimization, unless halide_target_feature_enable_llvm_loop_opt is set. DisableLLVMLoopOpt = halide_target_feature_disable_llvm_loop_opt, WasmSimd128 = halide_target_feature_wasm_simd128, WasmSignExt = halide_target_feature_wasm_signext, @@ -133,16 +161,22 @@ struct Target { ARMv81a = halide_target_feature_armv81a, SanitizerCoverage = halide_target_feature_sanitizer_coverage, ProfileByTimer = halide_target_feature_profile_by_timer, + SPIRV = halide_target_feature_spirv, FeatureEnd = halide_target_feature_end }; Target() = default; - Target(OS o, Arch a, int b, const std::vector &initial_features = std::vector()) - : os(o), arch(a), bits(b) { + Target(OS o, Arch a, int b, Processor pt, const std::vector &initial_features = std::vector(), + int vb = 0) + : os(o), arch(a), bits(b), vector_bits(vb), processor_tune(pt) { for (const auto &f : initial_features) { set_feature(f); } } + Target(OS o, Arch a, int b, const std::vector &initial_features = std::vector()) + : Target(o, a, b, ProcessorGeneric, initial_features) { + } + /** Given a string of the form used in HL_TARGET * (e.g. "x86-64-avx"), construct the Target it specifies. Note * that this always starts with the result of get_host_target(), @@ -226,6 +260,7 @@ struct Target { return os == other.os && arch == other.arch && bits == other.bits && + processor_tune == other.processor_tune && features == other.features; } @@ -247,7 +282,7 @@ struct Target { /** Convert the Target into a string form that can be reconstituted * by merge_string(), which will always be of the form * - * arch-bits-os-feature1-feature2...featureN. + * arch-bits-os-processor-feature1-feature2...featureN. * * Note that is guaranteed that Target(t1.to_string()) == t1, * but not that Target(s).to_string() == s (since there can be diff --git a/src/ThreadPool.h b/src/ThreadPool.h index 287acf29d88b..796a15fada1c 100644 --- a/src/ThreadPool.h +++ b/src/ThreadPool.h @@ -2,6 +2,7 @@ #define HALIDE_THREAD_POOL_H #include +#include #include #include #include diff --git a/src/Type.cpp b/src/Type.cpp index 0e575b328b74..64414fa04eca 100644 --- a/src/Type.cpp +++ b/src/Type.cpp @@ -298,7 +298,8 @@ std::string type_to_c_type(Type type, bool include_space, bool c_plus_plus) { if (modifier & halide_handle_cplusplus_type::Restrict) { oss << " restrict"; } - if (modifier & halide_handle_cplusplus_type::Pointer) { + if ((modifier & halide_handle_cplusplus_type::Pointer) && + !(modifier & halide_handle_cplusplus_type::FunctionTypedef)) { oss << " *"; } } diff --git a/src/Type.h b/src/Type.h index f020a4ae2625..63e65bd7f771 100644 --- a/src/Type.h +++ b/src/Type.h @@ -84,10 +84,11 @@ struct halide_handle_cplusplus_type { /// One set of modifiers on a type. /// The const/volatile/restrict properties are "inside" the pointer property. enum Modifier : uint8_t { - Const = 1 << 0, ///< Bitmask flag for "const" - Volatile = 1 << 1, ///< Bitmask flag for "volatile" - Restrict = 1 << 2, ///< Bitmask flag for "restrict" - Pointer = 1 << 3, ///< Bitmask flag for a pointer "*" + Const = 1 << 0, ///< Bitmask flag for "const" + Volatile = 1 << 1, ///< Bitmask flag for "volatile" + Restrict = 1 << 2, ///< Bitmask flag for "restrict" + Pointer = 1 << 3, ///< Bitmask flag for a pointer "*" + FunctionTypedef = 1 << 4, ///< Bitmask flag for a function typedef; when this is set, Pointer should also always be set }; /// Qualifiers and indirections on type. 0 is innermost. @@ -163,6 +164,8 @@ HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(int64_t); HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(uint64_t); HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(Halide::float16_t); HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(Halide::bfloat16_t); +HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(halide_task_t); +HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(halide_loop_task_t); HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(float); HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(double); HALIDE_DECLARE_EXTERN_STRUCT_TYPE(halide_buffer_t); @@ -196,11 +199,18 @@ template constexpr bool is_lvalue_reference = std::is_lvalue_reference::value; constexpr bool is_rvalue_reference = std::is_rvalue_reference::value; - using TBase = typename std::remove_pointer::type>::type; + using TNoRef = typename std::remove_reference::type; + using TNoRefNoPtr = typename std::remove_pointer::type; + constexpr bool is_function_pointer = std::is_pointer::value && + std::is_function::value; + + // Don't remove the pointer-ness from a function pointer. + using TBase = typename std::conditional::type; constexpr bool is_const = std::is_const::value; constexpr bool is_volatile = std::is_volatile::value; constexpr uint8_t modifiers = static_cast( + (is_function_pointer ? halide_handle_cplusplus_type::FunctionTypedef : 0) | (is_ptr ? halide_handle_cplusplus_type::Pointer : 0) | (is_const ? halide_handle_cplusplus_type::Const : 0) | (is_volatile ? halide_handle_cplusplus_type::Volatile : 0)); diff --git a/src/Util.cpp b/src/Util.cpp index 53f15fa14b2e..954f1378f726 100644 --- a/src/Util.cpp +++ b/src/Util.cpp @@ -313,7 +313,7 @@ std::string extract_namespaces(const std::string &name, std::vector return result; } -std::string extract_namespaces(const std::string &name) { +std::string strip_namespaces(const std::string &name) { std::vector unused; return extract_namespaces(name, unused); } @@ -661,6 +661,9 @@ size_t get_compiler_stack_size() { namespace Internal { +#ifdef HALIDE_INTERNAL_USING_ASAN +// nothing +#else namespace { // We can't reliably pass arguments through makecontext, because // the calling convention involves an invalid function pointer @@ -668,6 +671,7 @@ namespace { // platforms, so we use a thread local to pass arguments. thread_local void *run_with_large_stack_arg = nullptr; } // namespace +#endif void run_with_large_stack(const std::function &action) { if (stack_size.size == 0) { @@ -677,7 +681,6 @@ void run_with_large_stack(const std::function &action) { } #if _WIN32 - // Only exists for its address, which is used to compute remaining stack space. ULONG_PTR approx_stack_pos; @@ -719,6 +722,14 @@ void run_with_large_stack(const std::function &action) { #else // On posixy systems we have makecontext / swapcontext +#ifdef HALIDE_INTERNAL_USING_ASAN + // ... unless we are compiling under ASAN, in which case we + // will get a zillion warnings about ASAN not supporting makecontext/swapcontext + // and the possibility of false positives. Just skip the extra stack space, I guess? + action(); + return; +#else + #ifdef HALIDE_WITH_EXCEPTIONS struct Args { const std::function &run; @@ -783,6 +794,8 @@ void run_with_large_stack(const std::function &action) { } #endif +#endif // not ADDRESS_SANITIZER + #endif } diff --git a/src/Util.h b/src/Util.h index 6f551d4174d9..00489b1b3f34 100644 --- a/src/Util.h +++ b/src/Util.h @@ -45,6 +45,32 @@ #define HALIDE_NO_USER_CODE_INLINE HALIDE_NEVER_INLINE #endif +// Clang uses __has_feature() for sanitizers... +#if defined(__has_feature) +#if __has_feature(address_sanitizer) +#define HALIDE_INTERNAL_USING_ASAN +#endif +#if __has_feature(memory_sanitizer) +#define HALIDE_INTERNAL_USING_MSAN +#endif +#if __has_feature(thread_sanitizer) +#define HALIDE_INTERNAL_USING_TSAN +#endif +#if __has_feature(coverage_sanitizer) +#define HALIDE_INTERNAL_USING_COVSAN +#endif +#if __has_feature(undefined_behavior_sanitizer) +#define HALIDE_INTERNAL_USING_UBSAN +#endif +#endif + +// ...but GCC/MSVC don't like __has_feature, so handle them separately. +// (Only AddressSanitizer for now, not sure if any others are well-supported +// outside of Clang. +#if defined(__SANITIZE_ADDRESS__) && !defined(HALIDE_INTERNAL_USING_ASAN) +#define HALIDE_INTERNAL_USING_ASAN +#endif + namespace Halide { /** Load a plugin in the form of a dynamic library (e.g. for custom autoschedulers). @@ -208,8 +234,8 @@ struct all_are_convertible : meta_and...> {}; /** Returns base name and fills in namespaces, outermost one first in vector. */ std::string extract_namespaces(const std::string &name, std::vector &namespaces); -/** Overload that returns base name only */ -std::string extract_namespaces(const std::string &name); +/** Like extract_namespaces(), but strip and discard the namespaces, returning base name only */ +std::string strip_namespaces(const std::string &name); struct FileStat { uint64_t file_size; @@ -370,14 +396,13 @@ void halide_toc_impl(const char *file, int line); // regarding 'bool' in some compliation configurations. template struct StaticCast { - template::value>::type * = nullptr> - inline constexpr static TO2 value(const FROM &from) { - return static_cast(from); - } - - template::value>::type * = nullptr> - inline constexpr static TO2 value(const FROM &from) { - return from != 0; + template + inline constexpr static TO value(const FROM &from) { + if constexpr (std::is_same::value) { + return from != 0; + } else { + return static_cast(from); + } } }; @@ -386,19 +411,21 @@ struct StaticCast { // or dropping of fractional parts). template struct IsRoundtrippable { - template::value>::type * = nullptr> - inline constexpr static bool value(const FROM &from) { - return false; - } - - template::value && std::is_arithmetic::value && std::is_arithmetic::value && !std::is_same::value>::type * = nullptr> + template inline constexpr static bool value(const FROM &from) { - return StaticCast::value(StaticCast::value(from)) == from; - } - - template::value && !(std::is_arithmetic::value && std::is_arithmetic::value && !std::is_same::value)>::type * = nullptr> - inline constexpr static bool value(const FROM &from) { - return true; + if constexpr (std::is_convertible::value) { + if constexpr (std::is_arithmetic::value && + std::is_arithmetic::value && + !std::is_same::value) { + const TO to = static_cast(from); + const FROM roundtripped = static_cast(to); + return roundtripped == from; + } else { + return true; + } + } else { + return false; + } } }; diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp index e2960739ed86..7dcd79d24664 100644 --- a/src/VectorizeLoops.cpp +++ b/src/VectorizeLoops.cpp @@ -527,6 +527,16 @@ class VectorSubs : public IRMutator { } } + Expr visit(const Reinterpret *op) override { + Expr value = mutate(op->value); + if (value.same_as(op->value)) { + return op; + } else { + Type t = op->type.with_lanes(value.type().lanes()); + return Reinterpret::make(t, value); + } + } + string get_widened_var_name(const string &name) { return name + ".widened." + vectorized_vars.back().name; } diff --git a/src/WasmExecutor.cpp b/src/WasmExecutor.cpp index 68de5585de3b..b5750a30eb76 100644 --- a/src/WasmExecutor.cpp +++ b/src/WasmExecutor.cpp @@ -1842,7 +1842,9 @@ void wasm_jit_malloc_callback(const v8::FunctionCallbackInfo &args) { size_t size = args[0]->Int32Value(context).ToChecked() + kExtraMallocSlop; wasm32_ptr_t p = v8_WasmMemoryObject_malloc(context, size); - if (p) { p += kExtraMallocSlop; } + if (p) { + p += kExtraMallocSlop; + } args.GetReturnValue().Set(load_scalar(context, p)); } @@ -1851,7 +1853,9 @@ void wasm_jit_free_callback(const v8::FunctionCallbackInfo &args) { HandleScope scope(isolate); Local context = isolate->GetCurrentContext(); wasm32_ptr_t p = args[0]->Int32Value(context).ToChecked(); - if (p) { p -= kExtraMallocSlop; } + if (p) { + p -= kExtraMallocSlop; + } v8_WasmMemoryObject_free(context, p); } @@ -2116,7 +2120,7 @@ void add_extern_callbacks(const Local &context, continue; } - TrampolineFn trampoline_fn; + TrampolineFn trampoline_fn = nullptr; std::vector arg_types; if (!build_extern_arg_types(fn_name, jit_externs, trampolines, trampoline_fn, arg_types)) { internal_error << "Missing fn_name " << fn_name; @@ -2280,7 +2284,7 @@ struct WasmModuleContents { const std::map &jit_externs, const std::vector &extern_deps); - int run(const void **args); + int run(const void *const *args); ~WasmModuleContents() = default; }; @@ -2517,7 +2521,7 @@ WasmModuleContents::WasmModuleContents( #endif } -int WasmModuleContents::run(const void **args) { +int WasmModuleContents::run(const void *const *args) { #if WITH_WABT const auto &module_desc = module->desc(); @@ -2726,7 +2730,7 @@ WasmModule WasmModule::compile( } /** Run generated previously compiled wasm code with a set of arguments. */ -int WasmModule::run(const void **args) { +int WasmModule::run(const void *const *args) { internal_assert(contents.defined()); return contents->run(args); } diff --git a/src/WasmExecutor.h b/src/WasmExecutor.h index 07325a166729..f0871507ad02 100644 --- a/src/WasmExecutor.h +++ b/src/WasmExecutor.h @@ -11,16 +11,21 @@ */ #include "Argument.h" -#include "JITModule.h" #include "Parameter.h" #include "Type.h" +#include +#include +#include + namespace Halide { +struct JITExtern; struct Target; namespace Internal { +struct JITModule; struct WasmModuleContents; /** Handle to compiled wasm code which can be called later. */ @@ -39,7 +44,7 @@ struct WasmModule { const std::vector &extern_deps); /** Run generated previously compiled wasm code with a set of arguments. */ - int run(const void **args); + int run(const void *const *args); }; } // namespace Internal diff --git a/src/WrapCalls.h b/src/WrapCalls.h index f73ec2bf1f83..e54244bcf9f8 100644 --- a/src/WrapCalls.h +++ b/src/WrapCalls.h @@ -15,7 +15,7 @@ namespace Internal { class Function; /** Replace every call to wrapped Functions in the Functions' definitions with - * call to their wrapper functions. */ + * call to their wrapper functions. */ std::map wrap_func_calls(const std::map &env); } // namespace Internal diff --git a/src/autoschedulers/adams2019/AutoSchedule.cpp b/src/autoschedulers/adams2019/AutoSchedule.cpp index cdb2e14cc0ea..c8d40414b2ee 100644 --- a/src/autoschedulers/adams2019/AutoSchedule.cpp +++ b/src/autoschedulers/adams2019/AutoSchedule.cpp @@ -20,26 +20,31 @@ Environment variables used (directly or indirectly): - HL_BEAM_SIZE - Beam size to use in the beam search. Defaults to 32. Use 1 to get a greedy search instead. + HL_DEBUG_AUTOSCHEDULE + If set, is used for the debug log level for auto-schedule generation (overriding the + value of HL_DEBUG_CODEGEN, if any). - HL_CYOS - "Choose-your-own-schedule". If set to 1, lets you navigate the search tree by hand in the terminal. Whee! This is for debugging the autoscheduler. + HL_PERMIT_FAILED_UNROLL + Set to 1 to tell Halide not to freak out if we try to unroll a loop that doesn't have a constant extent. Should generally not be necessary, but sometimes the autoscheduler's model for what will and will not turn into a constant during lowering is inaccurate, because Halide isn't perfect at constant-folding. - HL_FEATURE_FILE -> output - *** DEPRECATED *** use the 'featurization' output from Generator instead - Write out a training featurization for the selected schedule into this file. - Needs to be converted to a sample file with the runtime using featurization_to_sample before it can be used to train. +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API - HL_MACHINE_PARAMS - An architecture description string. Used by Halide master to configure the cost model. We only use the first term. Set it to the number of cores to target. + Most of the settings in this Autoscheduler are controlled by the values specified via + an `autoscheduler.fieldname` GeneratorParam, as listed in the Adams2019Params struct; + this is the preferred way to set these. - HL_PERMIT_FAILED_UNROLL - Set to 1 to tell Halide not to freak out if we try to unroll a loop that doesn't have a constant extent. Should generally not be necessary, but sometimes the autoscheduler's model for what will and will not turn into a constant during lowering is inaccurate, because Halide isn't perfect at constant-folding. + For now, however, you can (instead) control these settings via env vars; + doing so requires that you compile all of Halide with HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + defined. (Note that this ability is deprecated, and likely to be removed in Halide 16.) + + That said, here are the (legacy) env vars you can still use when HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + is defined: + + HL_BEAM_SIZE + Beam size to use in the beam search. Defaults to 32. Use 1 to get a greedy search instead. - HL_SCHEDULE_FILE - *** DEPRECATED *** use the 'schedule' output from Generator instead - Write out a human-and-machine readable block of scheduling source code for the selected schedule into this file. + HL_CYOS + "Choose-your-own-schedule". If set to 1, lets you navigate the search tree by hand in the terminal. Whee! This is for debugging the autoscheduler. HL_RANDOM_DROPOUT percent chance of accepting each state in the beam. Normalized by the number of decisions made, so 5 would be there's a 5 percent chance of never rejecting any states. @@ -54,10 +59,6 @@ HL_NO_SUBTILING If set to 1, limits the search space to that of Mullapudi et al. - HL_DEBUG_AUTOSCHEDULE - If set, is used for the debug log level for auto-schedule generation (overriding the - value of HL_DEBUG_CODEGEN, if any). - HL_AUTOSCHEDULE_MEMORY_LIMIT If set, only consider schedules that allocate at most this much memory (measured in bytes). @@ -69,8 +70,20 @@ If set, then tiling sizes are not cached across passes. (see Cache.h for more information) - TODO: expose these settings by adding some means to pass args to - generator plugins instead of environment vars. +#endif + +#ifdef HALIDE_AUTOSCHEDULER_ALLOW_CYOS + + HL_CYOS + "Choose-your-own-schedule". + + If set to 1, lets you navigate the search tree by hand in the terminal. + Whee! This is for debugging the autoscheduler. Since it is generally only + for use by developers/maintainers of this autoscheduler, it defaults + to being omitted entirely unless you build Halide with HALIDE_AUTOSCHEDULER_ALLOW_CYOS defined. + Even then, you must *also* set the env var to 1 to make use of it. + +#endif */ #include "HalidePlugin.h" @@ -96,11 +109,11 @@ #include "Halide.h" #include "LoopNest.h" #include "NetworkSize.h" +#include "ParamParser.h" #include "PerfectHashMap.h" #include "State.h" #include "Timer.h" - #ifdef _WIN32 #include #define _isatty isatty; @@ -118,66 +131,73 @@ struct ProgressBar { if (!draw_progress_bar) { return; } + auto &os = aslog(ProgressBarLogLevel).get_ostream(); counter++; const int bits = 11; if (counter & ((1 << bits) - 1)) { return; } const int pos = (int)(progress * 78); - aslog(0) << "["; + os << "["; for (int j = 0; j < 78; j++) { if (j < pos) { - aslog(0) << "."; + os << "."; } else if (j - 1 < pos) { - aslog(0) << "/-\\|"[(counter >> bits) % 4]; + os << "/-\\|"[(counter >> bits) % 4]; } else { - aslog(0) << " "; + os << " "; } } - aslog(0) << "]"; + os << "]"; for (int j = 0; j < 80; j++) { - aslog(0) << "\b"; + os << "\b"; } } void clear() { if (counter) { + auto &os = aslog(ProgressBarLogLevel).get_ostream(); for (int j = 0; j < 80; j++) { - aslog(0) << " "; + os << " "; } for (int j = 0; j < 80; j++) { - aslog(0) << "\b"; + os << "\b"; } } } private: uint32_t counter = 0; - const bool draw_progress_bar = isatty(2); + static constexpr int ProgressBarLogLevel = 1; + const bool draw_progress_bar = isatty(2) && aslog::aslog_level() >= ProgressBarLogLevel; }; -// Get the HL_RANDOM_DROPOUT environment variable. Purpose of this is described above. -uint32_t get_dropout_threshold() { - string random_dropout_str = get_env_variable("HL_RANDOM_DROPOUT"); - if (!random_dropout_str.empty()) { - return atoi(random_dropout_str.c_str()); - } else { - return 100; +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API +template +T get_scalar_env_var(const char *nm, T def = T()) { + auto str = get_env_variable(nm); + if (str.empty()) { + return def; } + std::istringstream iss(str); + T t; + iss >> t; + user_assert(!iss.fail() && iss.get() == EOF) << "Unable to parse: " << str; + return t; } +#endif // Decide whether or not to drop a beam search state. Used for // randomly exploring the search tree for autotuning and to generate // training data. -bool random_dropout(std::mt19937 &rng, size_t num_decisions) { - static double random_dropout_threshold = get_dropout_threshold(); - if (random_dropout_threshold >= 100) { +bool random_dropout(const Adams2019Params ¶ms, std::mt19937 &rng, size_t num_decisions) { + if (params.random_dropout >= 100) { return false; } // The random dropout threshold is the chance that we operate // entirely greedily and never discard anything. - double t = random_dropout_threshold; + double t = params.random_dropout; t /= 100; t = std::pow(t, 1.0f / num_decisions); t *= 100; @@ -187,7 +207,6 @@ bool random_dropout(std::mt19937 &rng, size_t num_decisions) { return drop_it; } - // A priority queue of states, sorted according to increasing // cost. Never shrinks, to avoid reallocations. // Can't use std::priority_queue because it doesn't support unique_ptr. @@ -255,7 +274,7 @@ class StateQueue { // Configure a cost model to process a specific pipeline. void configure_pipeline_features(const FunctionDAG &dag, - const MachineParams ¶ms, + const Adams2019Params ¶ms, CostModel *cost_model) { cost_model->reset(); cost_model->set_pipeline_features(dag, params); @@ -264,11 +283,9 @@ void configure_pipeline_features(const FunctionDAG &dag, // A single pass of coarse-to-fine beam search. IntrusivePtr optimal_schedule_pass(FunctionDAG &dag, const vector &outputs, - const MachineParams ¶ms, + const Adams2019Params ¶ms, CostModel *cost_model, std::mt19937 &rng, - int beam_size, - int64_t memory_limit, int pass_idx, int num_passes, ProgressBar &tick, @@ -292,15 +309,11 @@ IntrusivePtr optimal_schedule_pass(FunctionDAG &dag, std::function &&)> enqueue_new_children = [&](IntrusivePtr &&s) { - // aslog(0) << "\n** Generated child: "; - // s->dump(); - // s->calculate_cost(dag, params, nullptr, true); - // Each child should have one more decision made than its parent state. internal_assert(s->num_decisions_made == s->parent->num_decisions_made + 1); - int progress = s->num_decisions_made * beam_size + expanded; - size_t max_progress = dag.nodes.size() * beam_size * 2; + int progress = s->num_decisions_made * params.beam_size + expanded; + size_t max_progress = dag.nodes.size() * params.beam_size * 2; // Update the progress bar tick.set(double(progress) / max_progress); @@ -310,7 +323,9 @@ IntrusivePtr optimal_schedule_pass(FunctionDAG &dag, q.emplace(std::move(s)); }; +#ifdef HALIDE_AUTOSCHEDULER_ALLOW_CYOS string cyos_str = get_env_variable("HL_CYOS"); +#endif // This loop is beam search over the sequence of decisions to make. for (int i = 0;; i++) { @@ -318,37 +333,37 @@ IntrusivePtr optimal_schedule_pass(FunctionDAG &dag, q.swap(pending); if (pending.empty()) { - if ((false) && beam_size < 1000) { // Intentional dead code. Extra parens to pacify clang-tidy. + if ((false) && params.beam_size < 1000) { // Intentional dead code. Extra parens to pacify clang-tidy. // Total mortality. Double the beam size and // restart. Disabled for now because total mortality // may indicate a bug. + Adams2019Params params2 = params; + params2.beam_size *= 2; return optimal_schedule_pass(dag, outputs, - params, + params2, cost_model, rng, - beam_size * 2, - memory_limit, pass_idx, num_passes, tick, permitted_hashes, cache); } else { - internal_error << "Ran out of legal states with beam size " << beam_size << "\n"; + internal_error << "Ran out of legal states with beam size " << params.beam_size << "\n"; } } - if ((int)pending.size() > beam_size * 10000) { - aslog(0) << "Warning: Huge number of states generated (" << pending.size() << ").\n"; + if ((int)pending.size() > params.beam_size * 10000) { + aslog(1) << "*** Warning: Huge number of states generated (" << pending.size() << ").\n"; } expanded = 0; - while (expanded < beam_size && !pending.empty()) { + while (expanded < params.beam_size && !pending.empty()) { IntrusivePtr state{pending.pop()}; - if (beam_size > 1 && num_passes > 1) { + if (params.beam_size > 1 && num_passes > 1) { // We are doing coarse-to-fine beam search using the // hashing strategy mentioned in the paper. // @@ -386,7 +401,7 @@ IntrusivePtr optimal_schedule_pass(FunctionDAG &dag, } // Random dropout - if (pending.size() > 1 && random_dropout(rng, dag.nodes.size() * 2)) { + if (pending.size() > 1 && random_dropout(params, rng, dag.nodes.size() * 2)) { continue; } @@ -403,7 +418,7 @@ IntrusivePtr optimal_schedule_pass(FunctionDAG &dag, // there are more coarse-to-fine passes yet to come. if (pass_idx + 1 < num_passes) { int blessed = 0; - while (state->cost <= 1.2 * best->cost && blessed < beam_size) { + while (state->cost <= 1.2 * best->cost && blessed < params.beam_size) { const State *s = state.get(); while (s) { uint64_t h1 = s->structural_hash(pass_idx); @@ -421,7 +436,7 @@ IntrusivePtr optimal_schedule_pass(FunctionDAG &dag, return best; } - state->generate_children(dag, params, cost_model, memory_limit, enqueue_new_children, cache); + state->generate_children(dag, params, cost_model, enqueue_new_children, cache); expanded++; } @@ -434,45 +449,44 @@ IntrusivePtr optimal_schedule_pass(FunctionDAG &dag, q.resort(); } +#ifdef HALIDE_AUTOSCHEDULER_ALLOW_CYOS if (cyos_str == "1") { // The user has set HL_CYOS, and wants to navigate the // search space manually. Discard everything in the queue // except for the user-chosen option. - aslog(0) << "\n--------------------\n"; - aslog(0) << "Select a schedule:\n"; + std::cout << "\n--------------------\n"; + std::cout << "Select a schedule:\n"; for (int choice_label = (int)q.size() - 1; choice_label >= 0; choice_label--) { auto state = q[choice_label]; - aslog(0) << "\n[" << choice_label << "]:\n"; - state->dump(); - aslog(0) << "\nFeature vector: " << state->dump(true) << "\n"; - state->calculate_cost(dag, params, cost_model, cache->options, memory_limit, true); + std::cout << "\n[" << choice_label << "]:\n"; + state->dump(std::cout); + constexpr int verbosity_level = 0; // always + state->calculate_cost(dag, params, cost_model, cache->options, verbosity_level); } cost_model->evaluate_costs(); // Select next partial schedule to expand. int selection = -1; while (selection < 0 || selection >= (int)q.size()) { - aslog(0) << "\nEnter selection: "; + std::cout << "\nEnter selection: "; std::cin >> selection; } auto selected = q[selection]; - selected->dump(); - aslog(0) << "\nFeature vector: " << selected->dump(true) << "\n"; + selected->dump(std::cout); q.clear(); q.emplace(std::move(selected)); } +#endif } } // Performance coarse-to-fine beam search and return the best state found. IntrusivePtr optimal_schedule(FunctionDAG &dag, const vector &outputs, - const MachineParams ¶ms, + const Adams2019Params ¶ms, CostModel *cost_model, std::mt19937 &rng, - int beam_size, - int64_t memory_limit, const CachingOptions &options) { IntrusivePtr best; @@ -483,14 +497,16 @@ IntrusivePtr optimal_schedule(FunctionDAG &dag, Cache cache(options, dag.nodes.size()); // If the beam size is one, it's pointless doing multiple passes. - int num_passes = (beam_size == 1) ? 1 : 5; + int num_passes = (params.beam_size == 1) ? 1 : 5; +#ifdef HALIDE_AUTOSCHEDULER_ALLOW_CYOS string cyos_str = get_env_variable("HL_CYOS"); if (cyos_str == "1") { // If the user is manually navigating the search space, don't // ask them to do more than one pass. num_passes = 1; } +#endif string num_passes_str = get_env_variable("HL_NUM_PASSES"); if (!num_passes_str.empty()) { @@ -498,27 +514,29 @@ IntrusivePtr optimal_schedule(FunctionDAG &dag, num_passes = std::atoi(num_passes_str.c_str()); } - Timer timer; - for (int i = 0; i < num_passes; i++) { ProgressBar tick; Timer timer; auto pass = optimal_schedule_pass(dag, outputs, params, cost_model, - rng, beam_size, memory_limit, - i, num_passes, tick, permitted_hashes, &cache); + rng, i, num_passes, tick, permitted_hashes, &cache); std::chrono::duration total_time = timer.elapsed(); auto milli = std::chrono::duration_cast(total_time).count(); tick.clear(); - if (aslog::aslog_level() == 0) { - aslog(0) << "Pass " << i << " of " << num_passes << ", cost: " << pass->cost << ", time (ms): " << milli << "\n"; - } else { - aslog(0) << "Pass " << i << " result: "; - pass->dump(); + switch (aslog::aslog_level()) { + case 0: + // Silence + break; + case 1: + aslog(1) << "Pass " << i << " of " << num_passes << ", cost: " << pass->cost << ", time (ms): " << milli << "\n"; + break; + default: + aslog(2) << "Pass " << i << " result: "; + pass->dump(aslog(2).get_ostream()); } if (i == 0 || pass->cost < best->cost) { @@ -528,12 +546,12 @@ IntrusivePtr optimal_schedule(FunctionDAG &dag, } } - std::chrono::duration total_time = timer.elapsed(); - auto milli = std::chrono::duration_cast(total_time).count(); - - aslog(0) << "Best cost: " << best->cost << "\n"; + aslog(1) << "Best cost: " << best->cost << "\n"; - aslog(0) << "Execution time: " << milli << " ms\n\n"; + if (options.cache_blocks) { + aslog(1) << "Cache (block) hits: " << cache.cache_hits << "\n"; + aslog(1) << "Cache (block) misses: " << cache.cache_misses << "\n"; + } return best; } @@ -544,47 +562,36 @@ int State::cost_calculations = 0; // The main entrypoint to generate a schedule for a pipeline. void generate_schedule(const std::vector &outputs, const Target &target, - const MachineParams ¶ms, + const Adams2019Params ¶ms, AutoSchedulerResults *auto_scheduler_results) { - aslog(0) << "generate_schedule for target=" << target.to_string() << "\n"; + aslog(1) << "generate_schedule for target=" << target.to_string() << "\n"; + aslog(1) << "Adams2019.parallelism:" << params.parallelism << "\n"; + aslog(1) << "Adams2019.beam_size:" << params.beam_size << "\n"; + aslog(1) << "Adams2019.random_dropout:" << params.random_dropout << "\n"; + aslog(1) << "Adams2019.random_dropout_seed:" << params.random_dropout_seed << "\n"; + aslog(1) << "Adams2019.weights_path:" << params.weights_path << "\n"; + aslog(1) << "Adams2019.disable_subtiling:" << params.disable_subtiling << "\n"; + aslog(1) << "Adams2019.disable_memoized_features:" << params.disable_memoized_features << "\n"; + aslog(1) << "Adams2019.disable_memoized_blocks:" << params.disable_memoized_blocks << "\n"; + aslog(1) << "Adams2019.memory_limit:" << params.memory_limit << "\n"; // Start a timer HALIDE_TIC; State::cost_calculations = 0; - // Get the seed for random dropout - string seed_str = get_env_variable("HL_SEED"); - std::cout << "HL_SEED is " << seed_str << " from inside Adams2019" << std::endl; - // Or use the time, if not set. - int seed = (int)time(nullptr); - if (!seed_str.empty()) { - seed = atoi(seed_str.c_str()); - } - aslog(1) << "Dropout seed = " << seed << "\n"; - std::mt19937 rng((uint32_t)seed); - - // Get the beam size - string beam_size_str = get_env_variable("HL_BEAM_SIZE"); - // Defaults to 32 - size_t beam_size = 32; - if (!beam_size_str.empty()) { - beam_size = atoi(beam_size_str.c_str()); - } + std::mt19937 rng((uint32_t)params.random_dropout_seed); - string weights_in_path = get_env_variable("HL_WEIGHTS_DIR"); + string weights_in_path = params.weights_path; string weights_out_path; // deliberately empty string randomize_weights_str = get_env_variable("HL_RANDOMIZE_WEIGHTS"); bool randomize_weights = randomize_weights_str == "1"; - string memory_limit_str = get_env_variable("HL_AUTOSCHEDULE_MEMORY_LIMIT"); - int64_t memory_limit = memory_limit_str.empty() ? (uint64_t)(-1) : std::atoll(memory_limit_str.c_str()); - // Analyse the Halide algorithm and construct our abstract representation of it - FunctionDAG dag(outputs, params, target); - if (aslog::aslog_level() > 0) { - dag.dump(); + FunctionDAG dag(outputs, target); + if (aslog::aslog_level() >= 2) { + dag.dump(aslog(2).get_ostream()); } // Construct a cost model to use to evaluate states. Currently we @@ -596,10 +603,10 @@ void generate_schedule(const std::vector &outputs, IntrusivePtr optimal; // Options generated from environment variables, decide whether or not to cache features and/or tilings. - CachingOptions cache_options = CachingOptions::MakeOptionsFromEnviron(); + CachingOptions cache_options = CachingOptions::MakeOptionsFromParams(params); // Run beam search - optimal = optimal_schedule(dag, outputs, params, cost_model.get(), rng, beam_size, memory_limit, cache_options); + optimal = optimal_schedule(dag, outputs, params, cost_model.get(), rng, cache_options); HALIDE_TOC; @@ -609,76 +616,75 @@ void generate_schedule(const std::vector &outputs, aslog(1) << "** Optimal schedule:\n"; // Just to get the debugging prints to fire - optimal->calculate_cost(dag, params, cost_model.get(), cache_options, memory_limit, aslog::aslog_level() > 0); + optimal->calculate_cost(dag, params, cost_model.get(), cache_options, /*verbosity_level*/ 1); // Apply the schedules to the pipeline optimal->apply_schedule(dag, params); // Print out the schedule - if (aslog::aslog_level() > 0) { - optimal->dump(); - } - - // aslog(0) << "Source:" << optimal->schedule_source << "\n\n\n"; - - string schedule_file = get_env_variable("HL_SCHEDULE_FILE"); - if (!schedule_file.empty()) { - user_warning << "HL_SCHEDULE_FILE is deprecated; use the schedule output from Generator instead\n"; - aslog(1) << "Writing schedule to " << schedule_file << "...\n"; - std::ofstream f(schedule_file); - f << "// --- BEGIN machine-generated schedule\n" - << optimal->schedule_source - << "// --- END machine-generated schedule\n"; - f.close(); - internal_assert(!f.fail()) << "Failed to write " << schedule_file; - } - string python_schedule_file = get_env_variable("HL_PYTHON_SCHEDULE_FILE"); - if (!python_schedule_file.empty()) { - user_warning << "HL_PYTHON_SCHEDULE_FILE is deprecated; use the schedule output from Generator instead\n"; - aslog(1) << "Writing schedule to " << python_schedule_file << "...\n"; - std::ofstream f(python_schedule_file); - f << "# --- BEGIN machine-generated schedule\n" - << optimal->python_schedule_source - << "# --- END machine-generated schedule\n"; - f.close(); - internal_assert(!f.fail()) << "Failed to write " << python_schedule_file; - } - // Save the featurization, so that we can use this schedule as - // training data (once we've benchmarked it). - string feature_file = get_env_variable("HL_FEATURE_FILE"); - if (!feature_file.empty()) { - user_warning << "HL_FEATURE_FILE is deprecated; use the featurization output from Generator instead\n"; - std::ofstream binfile(feature_file, std::ios::binary | std::ios_base::trunc); - std::ofstream feature_file_index(feature_file + ".index"); - optimal->save_featurization(dag, params, cache_options, binfile, feature_file_index); - binfile.close(); - feature_file_index.close(); - internal_assert(!binfile.fail()) << "Failed to write " << feature_file; + if (aslog::aslog_level() >= 2) { + optimal->dump(aslog(2).get_ostream()); } if (auto_scheduler_results) { +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API auto_scheduler_results->scheduler_name = "Adams2019"; +#endif auto_scheduler_results->schedule_source = optimal->schedule_source; - auto_scheduler_results->python_schedule_source = optimal->python_schedule_source; - auto_scheduler_results->path_featurization = optimal->dump(true); { - std::ostringstream out, index_out; - optimal->save_featurization(dag, params, cache_options, out, index_out); + std::ostringstream out; + optimal->save_featurization(dag, params, cache_options, out); auto_scheduler_results->featurization.resize(out.str().size()); memcpy(auto_scheduler_results->featurization.data(), out.str().data(), out.str().size()); - auto_scheduler_results->featurization_index = index_out.str(); } } } struct Adams2019 { - void operator()(const Pipeline &p, const Target &target, const MachineParams ¶ms, AutoSchedulerResults *results) { +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + void operator()(const Pipeline &p, const Target &target, const MachineParams ¶ms_in, AutoSchedulerResults *results) { + std::vector outputs; + for (const Func &f : p.outputs()) { + outputs.push_back(f.function()); + } + Adams2019Params params; + params.parallelism = params_in.parallelism; + params.beam_size = get_scalar_env_var("HL_BEAM_SIZE", 32); + params.random_dropout = get_scalar_env_var("HL_RANDOM_DROPOUT", 100); + params.random_dropout_seed = get_scalar_env_var("HL_SEED", (int)time(nullptr)); + params.weights_path = get_scalar_env_var("HL_WEIGHTS_DIR"); + params.disable_subtiling = get_scalar_env_var("HL_NO_SUBTILING", 0); + params.disable_memoized_features = get_scalar_env_var("HL_DISABLE_MEMOIZED_FEATURES", 0); + params.disable_memoized_blocks = get_scalar_env_var("HL_DISABLE_MEMOIZED_BLOCKS", 0); + params.memory_limit = get_scalar_env_var("HL_AUTOSCHEDULE_MEMORY_LIMIT", -1); + Autoscheduler::generate_schedule(outputs, target, params, results); + } +#else + void operator()(const Pipeline &p, const Target &target, const AutoschedulerParams ¶ms_in, AutoSchedulerResults *results) { + internal_assert(params_in.name == "Adams2019"); + std::vector outputs; for (const Func &f : p.outputs()) { outputs.push_back(f.function()); } + Adams2019Params params; + { + ParamParser parser(params_in.extra); + parser.parse("parallelism", ¶ms.parallelism); + parser.parse("beam_size", ¶ms.beam_size); + parser.parse("random_dropout", ¶ms.random_dropout); + parser.parse("random_dropout_seed", ¶ms.random_dropout_seed); + parser.parse("weights_path", ¶ms.weights_path); + parser.parse("disable_subtiling", ¶ms.disable_subtiling); + parser.parse("disable_memoized_features", ¶ms.disable_memoized_features); + parser.parse("disable_memoized_blocks", ¶ms.disable_memoized_blocks); + parser.parse("memory_limit", ¶ms.memory_limit); + parser.finish(); + } Autoscheduler::generate_schedule(outputs, target, params, results); + results->autoscheduler_params = params_in; } +#endif }; REGISTER_AUTOSCHEDULER(Adams2019) @@ -686,15 +692,13 @@ REGISTER_AUTOSCHEDULER(Adams2019) // An alternative entrypoint for other uses void find_and_apply_schedule(FunctionDAG &dag, const std::vector &outputs, - const MachineParams ¶ms, + const Adams2019Params ¶ms, CostModel *cost_model, - int beam_size, - int64_t memory_limit, StageMap *schedule_features) { std::mt19937 rng(12345); - CachingOptions cache_options = CachingOptions::MakeOptionsFromEnviron(); - IntrusivePtr optimal = optimal_schedule(dag, outputs, params, cost_model, rng, beam_size, memory_limit, cache_options); + CachingOptions cache_options = CachingOptions::MakeOptionsFromParams(params); + IntrusivePtr optimal = optimal_schedule(dag, outputs, params, cost_model, rng, cache_options); // Apply the schedules optimal->apply_schedule(dag, params); diff --git a/src/autoschedulers/adams2019/AutoSchedule.h b/src/autoschedulers/adams2019/AutoSchedule.h index b7a76dc67e50..270ca7a24641 100644 --- a/src/autoschedulers/adams2019/AutoSchedule.h +++ b/src/autoschedulers/adams2019/AutoSchedule.h @@ -11,7 +11,7 @@ namespace Autoscheduler { typedef PerfectHashMap StageMapOfScheduleFeatures; -void find_and_apply_schedule(FunctionDAG &dag, const std::vector &outputs, const MachineParams ¶ms, +void find_and_apply_schedule(FunctionDAG &dag, const std::vector &outputs, const Adams2019Params ¶ms, CostModel *cost_model, int beam_size, StageMapOfScheduleFeatures *schedule_features); } // namespace Autoscheduler diff --git a/src/autoschedulers/adams2019/CMakeLists.txt b/src/autoschedulers/adams2019/CMakeLists.txt index dd1c11f1f3cf..5b4547de7143 100644 --- a/src/autoschedulers/adams2019/CMakeLists.txt +++ b/src/autoschedulers/adams2019/CMakeLists.txt @@ -1,4 +1,3 @@ -project(adams2019) ## # Resources for the autoscheduler library ## @@ -23,12 +22,11 @@ add_halide_library(train_cost_model FROM cost_model.generator # retrain_cost_model add_executable(retrain_cost_model - ASLog.cpp DefaultCostModel.cpp Weights.cpp retrain_cost_model.cpp ${WF_CPP}) -target_link_libraries(retrain_cost_model PRIVATE cost_model train_cost_model Halide::Halide Halide::Plugin) +target_link_libraries(retrain_cost_model PRIVATE ASLog cost_model train_cost_model Halide::Halide Halide::Plugin) ## # Main autoscheduler library @@ -36,7 +34,6 @@ target_link_libraries(retrain_cost_model PRIVATE cost_model train_cost_model Hal add_autoscheduler(NAME Adams2019 SOURCES - ASLog.cpp AutoSchedule.cpp Cache.cpp DefaultCostModel.cpp @@ -46,7 +43,7 @@ add_autoscheduler(NAME Adams2019 Weights.cpp ${WF_CPP}) -target_link_libraries(Halide_Adams2019 PRIVATE cost_model train_cost_model) +target_link_libraries(Halide_Adams2019 PRIVATE ASLog ParamParser cost_model train_cost_model) ## # Tests and demos @@ -71,7 +68,7 @@ add_test(NAME demo_apps_autoscheduler set_tests_properties(demo_apps_autoscheduler PROPERTIES - LABELS Adams2019 + LABELS "Adams2019;auto_schedule" ENVIRONMENT "HL_TARGET=${Halide_TARGET}") # ================================================================= @@ -92,7 +89,7 @@ add_test(NAME demo_included_schedule_file set_tests_properties(demo_included_schedule_file PROPERTIES - LABELS Adams2019 + LABELS "Adams2019;auto_schedule" ENVIRONMENT "HL_TARGET=${Halide_TARGET}") # ==================================================== @@ -115,10 +112,10 @@ if (BUILD_SHARED_LIBS) target_link_libraries(test_apps_autoscheduler PRIVATE Halide::Halide Halide::Tools ${CMAKE_DL_LIBS}) add_test(NAME test_apps_autoscheduler - COMMAND test_apps_autoscheduler $) + COMMAND test_apps_autoscheduler $ ${CMAKE_CURRENT_SOURCE_DIR}/baseline.weights) set_tests_properties(test_apps_autoscheduler PROPERTIES - LABELS Adams2019 + LABELS "Adams2019;multithreaded;auto_schedule" ENVIRONMENT "LD_LIBRARY_PATH=$:$ENV{LD_LIBRARY_PATH};HL_TARGET=${Halide_TARGET}") endif () @@ -129,16 +126,16 @@ add_executable(test_perfect_hash_map test_perfect_hash_map.cpp) add_test(NAME test_perfect_hash_map COMMAND test_perfect_hash_map) set_tests_properties(test_perfect_hash_map PROPERTIES - LABELS Adams2019 + LABELS "Adams2019;auto_schedule" ENVIRONMENT "HL_TARGET=${Halide_TARGET}") ## -add_executable(test_function_dag test_function_dag.cpp FunctionDAG.cpp ASLog.cpp) -target_link_libraries(test_function_dag PRIVATE Halide::Halide Halide::Tools Halide::Plugin) +add_executable(test_function_dag test_function_dag.cpp FunctionDAG.cpp) +target_link_libraries(test_function_dag PRIVATE ASLog Halide::Halide Halide::Tools Halide::Plugin) add_test(NAME test_function_dag COMMAND test_function_dag) set_tests_properties(test_function_dag PROPERTIES - LABELS Adams2019 + LABELS "Adams2019;auto_schedule" ENVIRONMENT "HL_TARGET=${Halide_TARGET}") diff --git a/src/autoschedulers/adams2019/Cache.cpp b/src/autoschedulers/adams2019/Cache.cpp index 8001a938e9e3..67aa9b3eccb7 100644 --- a/src/autoschedulers/adams2019/Cache.cpp +++ b/src/autoschedulers/adams2019/Cache.cpp @@ -1,28 +1,17 @@ #include "Cache.h" #include "LoopNest.h" #include "State.h" -#include -using namespace std; namespace Halide { namespace Internal { namespace Autoscheduler { -bool use_memoized_features() { - return get_env_variable("HL_DISABLE_MEMOIZED_FEATURES") != "1"; -} - -bool is_memoize_blocks_enabled() { - return get_env_variable("HL_DISABLE_MEMOIZED_BLOCKS") != "1"; -} - bool Cache::add_memoized_blocks(const State *state, std::function &&)> &accept_child, const FunctionDAG::Node *node, int &num_children, const FunctionDAG &dag, - const MachineParams ¶ms, - CostModel *cost_model, - int64_t memory_limit) const { + const Adams2019Params ¶ms, + CostModel *cost_model) const { if (!options.cache_blocks || !memoized_compute_root_blocks.contains(node)) { // either memoization is turned off, or we haven't cached this node yet. return false; @@ -71,7 +60,7 @@ bool Cache::add_memoized_blocks(const State *state, new_root->children[block_index++] = new_block; } - if (child->calculate_cost(dag, params, cost_model, this->options, memory_limit)) { + if (child->calculate_cost(dag, params, cost_model, this->options)) { num_children++; accept_child(std::move(child)); cache_hits++; @@ -104,9 +93,9 @@ void Cache::memoize_blocks(const FunctionDAG::Node *node, LoopNest *new_root) { for (auto &child : new_root->children) { if (child->node == node) { - LoopNest *new_block = new LoopNest; // Need const reference for copy. const LoopNest *child_ptr = child.get(); + LoopNest *new_block = new LoopNest; new_block->copy_from_including_features(*child_ptr); blocks.emplace_back(new_block); cache_misses++; diff --git a/src/autoschedulers/adams2019/Cache.h b/src/autoschedulers/adams2019/Cache.h index ded209bac1c1..b48ad2e46d39 100644 --- a/src/autoschedulers/adams2019/Cache.h +++ b/src/autoschedulers/adams2019/Cache.h @@ -31,7 +31,7 @@ namespace Autoscheduler { Important changes that caching impacts, outside of this file and Cache.cpp: - LoopNest::compute_features - If cache_features is enabled (i.e. HL_DISABLE_MEMOIZED_FEATURES!=1) then this function caches + If cache_features is enabled (i.e. disable_memoized_features==0) then this function caches the featurizations of its children, and if called again, reuses those cached featurizations. The features are saved in a LoopNest's member, std::map<> features_cache. Some features do not persist, and the FeaturesIntermediates struct (see Featurization.h) is used to cache useful @@ -48,7 +48,7 @@ namespace Autoscheduler { Computes a structural hash for use in feature caching in a LoopNest. - LoopNest::collect_producers - Collects all producers for a LoopNest for use in calculating the structural hash in + Collects all producers for a LoopNest for use in calculating the structural hash in LoopNest::compute_hash_of_producers_stored_at_root. - LoopNest::collect_stages @@ -68,12 +68,6 @@ namespace Autoscheduler { struct State; -// true unless HL_DISABLE_MEMOIZED_FEATURES=1 -bool use_memoized_features(); - -// true unless HL_DISABLE_MEMOIZED_BLOCKS=1 -bool is_memoize_blocks_enabled(); - /* Object stores caching options for autoscheduling. cache_blocks: decides if tilings are cached for decisions related to parallelizing the loops of a Func. @@ -83,10 +77,10 @@ struct CachingOptions { bool cache_blocks = false; bool cache_features = false; - static CachingOptions MakeOptionsFromEnviron() { + static CachingOptions MakeOptionsFromParams(const Adams2019Params ¶ms) { CachingOptions options; - options.cache_blocks = is_memoize_blocks_enabled(); - options.cache_features = use_memoized_features(); + options.cache_blocks = params.disable_memoized_blocks == 0; + options.cache_features = params.disable_memoized_features == 0; return options; } }; @@ -122,9 +116,8 @@ struct Cache { const FunctionDAG::Node *node, int &num_children, const FunctionDAG &dag, - const MachineParams ¶ms, - CostModel *cost_model, - int64_t memory_limit) const; + const Adams2019Params ¶ms, + CostModel *cost_model) const; // Generate tilings for a specific vector dimension and memoize them. void memoize_blocks(const FunctionDAG::Node *node, LoopNest *new_root); diff --git a/src/autoschedulers/adams2019/CostModel.h b/src/autoschedulers/adams2019/CostModel.h index 8459932c8dca..5335de1c50be 100644 --- a/src/autoschedulers/adams2019/CostModel.h +++ b/src/autoschedulers/adams2019/CostModel.h @@ -3,6 +3,7 @@ #include +#include "Featurization.h" #include "FunctionDAG.h" #include "HalideBuffer.h" #include "PerfectHashMap.h" @@ -12,7 +13,48 @@ namespace Halide { namespace Internal { namespace Autoscheduler { + typedef PerfectHashMap StageMapOfScheduleFeatures; + +struct Adams2019Params { + /** Maximum level of parallelism available. */ + int parallelism = 16; + + /** Beam size to use in the beam search. Defaults to 32. Use 1 to get a greedy search instead. + * Formerly HL_BEAM_SIZE */ + int beam_size = 32; + + /** percent chance of accepting each state in the beam. + * Normalized by the number of decisions made, so 5 would be there's a 5 percent chance of never rejecting any states. + * Formerly HL_RANDOM_DROPOUT */ + int random_dropout = 100; + + /** Random seed used by the random dropout. If 0, use time(). + * Formerly HL_SEED */ + int random_dropout_seed = 0; + + /** When training or schedule, read weights from this directory or file. + * (If path ends in `.weights` it is written as a single file, otherwise a directory of files.) + * Formerly HL_WEIGHTS_DIR */ + std::string weights_path; + + /** If set to nonzero value: limits the search space to that of Mullapudi et al. + * Formerly HL_NO_SUBTILING */ + int disable_subtiling = 0; + + /** If set to nonzero value: features of possible schedules are always recalculated, and are not cached across passes. + * Formerly HL_DISABLE_MEMOIZED_FEATURES */ + int disable_memoized_features = 0; + + /** If set to nonzero value: tiling sizes are not cached across passes. + * Formerly HL_DISABLE_MEMOIZED_BLOCKS */ + int disable_memoized_blocks = 0; + + /** If >= 0, only consider schedules that allocate at most this much memory (measured in bytes). + * Formerly HL_AUTOSCHEDULE_MEMORY_LIMIT */ + int64_t memory_limit = -1; +}; + } // namespace Autoscheduler } // namespace Internal @@ -22,7 +64,7 @@ class CostModel { // Configure the cost model for the algorithm to be scheduled. virtual void set_pipeline_features(const Internal::Autoscheduler::FunctionDAG &dag, - const MachineParams ¶ms) = 0; + const Internal::Autoscheduler::Adams2019Params ¶ms) = 0; // Enqueue a schedule to be evaluated. Will annotate the value located at cost_ptr when the evaluation takes place. // Note that the dag argument should correspond to the dag specified previously when calling set_pipeline_features. diff --git a/src/autoschedulers/adams2019/DefaultCostModel.cpp b/src/autoschedulers/adams2019/DefaultCostModel.cpp index 7d036f8a5888..7339261920bc 100644 --- a/src/autoschedulers/adams2019/DefaultCostModel.cpp +++ b/src/autoschedulers/adams2019/DefaultCostModel.cpp @@ -47,7 +47,7 @@ bool ends_with(const std::string &str, const std::string &suffix) { } // namespace void DefaultCostModel::set_pipeline_features(const Internal::Autoscheduler::FunctionDAG &dag, - const MachineParams ¶ms) { + const Internal::Autoscheduler::Adams2019Params ¶ms) { const int pipeline_feat_size = head1_w * head1_h; // We ignore the first seven pipeline features in the cost @@ -135,8 +135,8 @@ void DefaultCostModel::enqueue(const Internal::Autoscheduler::FunctionDAG &dag, void DefaultCostModel::enqueue(int ns, Runtime::Buffer *schedule_feats, double *cost_ptr) { num_stages = ns; - // We know the most stages that will ever be enqueued from the schedule features - internal_assert(pipeline_feat_queue.data() && "Call set_schedule_features before calling enqueue\n"); + // We know the most stages that will ever be enqueued from the pipeline features + internal_assert(pipeline_feat_queue.data() && "Call set_pipeline_features before calling enqueue\n"); const int max_num_stages = pipeline_feat_queue.dim(2).extent(); internal_assert(num_stages <= max_num_stages) << "schedule features has more stages (" << num_stages @@ -232,18 +232,18 @@ float DefaultCostModel::backprop(const Runtime::Buffer &true_runtim *(cost_ptrs(i)) = dst(i); if (std::isnan(dst(i))) { any_nans = true; - aslog(0) << "Prediction " << i << " is NaN. True runtime is " << true_runtimes(i) << "\n"; - aslog(0) << "Checking pipeline features for NaNs...\n"; + aslog(1) << "Prediction " << i << " is NaN. True runtime is " << true_runtimes(i) << "\n"; + aslog(1) << "Checking pipeline features for NaNs...\n"; pipeline_feat_queue.for_each_value([&](float f) { if (std::isnan(f)) abort(); }); - aslog(0) << "None found\n"; - aslog(0) << "Checking schedule features for NaNs...\n"; + aslog(1) << "None found\n"; + aslog(1) << "Checking schedule features for NaNs...\n"; schedule_feat_queue.for_each_value([&](float f) { if (std::isnan(f)) abort(); }); - aslog(0) << "None found\n"; - aslog(0) << "Checking network weights for NaNs...\n"; + aslog(1) << "None found\n"; + aslog(1) << "Checking network weights for NaNs...\n"; weights.for_each_buffer([&](const Runtime::Buffer &buf) { buf.for_each_value([&](float f) { if (std::isnan(f)) abort(); }); }); - aslog(0) << "None found\n"; + aslog(1) << "None found\n"; } internal_assert(true_runtimes(i) > 0); } @@ -350,17 +350,8 @@ void DefaultCostModel::load_weights() { } if (need_randomize) { - // Get the seed for random weights - std::string seed_str = Internal::get_env_variable("HL_RANDOM_WEIGHT_SEED"); - // Or use the time, if not set. - int seed = (int)time(nullptr); - if (!seed_str.empty()) { - std::cout << "Randomizing with HL_RANDOM_WEIGHT_SEED, seed = " << seed_str << "\n"; - seed = atoi(seed_str.c_str()); - } - else{ - std::cout << "Randomizing weights using time-based seed = " << seed << "\n"; - } + auto seed = time(nullptr); + std::cout << "Randomizing weights using seed = " << seed << "\n"; weights.randomize((uint32_t)seed); } diff --git a/src/autoschedulers/adams2019/DefaultCostModel.h b/src/autoschedulers/adams2019/DefaultCostModel.h index 11dff14ef0dc..9f7d6ac6c39b 100644 --- a/src/autoschedulers/adams2019/DefaultCostModel.h +++ b/src/autoschedulers/adams2019/DefaultCostModel.h @@ -7,6 +7,12 @@ namespace Halide { +namespace Internal { +namespace Autoscheduler { +struct Adams2019Params; +} // namespace Autoscheduler +} // namespace Internal + class DefaultCostModel : public CostModel { private: Internal::Weights weights; @@ -37,7 +43,7 @@ class DefaultCostModel : public CostModel { // Configure the cost model for the algorithm to be scheduled. void set_pipeline_features(const Internal::Autoscheduler::FunctionDAG &dag, - const MachineParams ¶ms) override; + const Internal::Autoscheduler::Adams2019Params ¶ms) override; void set_pipeline_features(const Runtime::Buffer &, int n); // Enqueue a schedule to be evaluated. The second version of this method returns a buffer of diff --git a/src/autoschedulers/adams2019/Featurization.h b/src/autoschedulers/adams2019/Featurization.h index 0e050bcc55a4..ed1ff00b81c3 100644 --- a/src/autoschedulers/adams2019/Featurization.h +++ b/src/autoschedulers/adams2019/Featurization.h @@ -5,8 +5,6 @@ #include #include -#include "ASLog.h" - namespace Halide { namespace Internal { @@ -99,8 +97,7 @@ struct PipelineFeatures { // Each row sums to 1 or 0. Each column sums to 1. f(z, y, x, 4) int slice_accesses[(int)AccessType::NumAccessTypes][(int)ScalarType::NumScalarTypes] = {}; - template - void dump(OS &os) const { + void dump(std::ostream &os) const { for (int i = 0; i < (int)ScalarType::NumScalarTypes; i++) { const char *type_names[] = {"Bool", "UInt8", "UInt16", "UInt32", "UInt64", "Float", "Double"}; // Skip printing for types not used @@ -157,10 +154,6 @@ struct PipelineFeatures { << slice_accesses[3][i] << "\n"; } } - void dump() const { - auto os = aslog(0); - dump(os); - } }; // The schedule-dependent portion of the featurization of a stage @@ -314,8 +307,7 @@ struct ScheduleFeatures { double working_set_at_realization = 0; double working_set_at_root = 0; - template - void dump(OS &os) const { + void dump(std::ostream &os) const { os << " num_realizations: " << num_realizations << "\n" << " num_productions: " << num_productions << "\n" << " points_computed_per_realization: " << points_computed_per_realization << "\n" @@ -356,10 +348,6 @@ struct ScheduleFeatures { << " working_set_at_realization: " << working_set_at_realization << "\n" << " working_set_at_root: " << working_set_at_root << "\n"; } - void dump() const { - auto os = aslog(0); - dump(os); - } bool equal(const ScheduleFeatures &other) const { const size_t n_features = ScheduleFeatures::num_features(); diff --git a/src/autoschedulers/adams2019/FunctionDAG.cpp b/src/autoschedulers/adams2019/FunctionDAG.cpp index d3587e568727..97682ca683be 100644 --- a/src/autoschedulers/adams2019/FunctionDAG.cpp +++ b/src/autoschedulers/adams2019/FunctionDAG.cpp @@ -307,42 +307,44 @@ class Featurizer : public IRVisitor { } // namespace -void LoadJacobian::dump(const char *prefix) const { +void LoadJacobian::dump(std::ostream &os, const char *prefix) const { if (count() > 1) { - aslog(0) << prefix << count() << " x\n"; + os << prefix << count() << " x\n"; } for (size_t i = 0; i < producer_storage_dims(); i++) { - aslog(0) << prefix << " ["; + os << prefix << " ["; for (size_t j = 0; j < consumer_loop_dims(); j++) { const auto &c = (*this)(i, j); if (!c.exists) { - aslog(0) << " _ "; + os << " _ "; } else if (c.denominator == 1) { - aslog(0) << " " << c.numerator << " "; + os << " " << c.numerator << " "; } else { - aslog(0) << c.numerator << "/" << c.denominator << " "; + os << c.numerator << "/" << c.denominator << " "; } } - aslog(0) << "]\n"; + os << "]\n"; } - aslog(0) << "\n"; + os << "\n"; } void BoundContents::validate() const { for (int i = 0; i < layout->total_size; i++) { auto p = data()[i]; if (p.max() < p.min()) { - aslog(0) << "Bad bounds object:\n"; + std::ostringstream err; + err << "Bad bounds object:\n"; for (int j = 0; j < layout->total_size; j++) { if (i == j) { - aslog(0) << "=> "; + err << "=> "; } else { - aslog(0) << " "; + err << " "; } - aslog(0) << j << ": " << data()[j].min() << ", " << data()[j].max() << "\n"; + err << j << ": " << data()[j].min() << ", " << data()[j].max() << "\n"; } - internal_error << "Aborting"; + err << "Aborting"; + internal_error << err.str(); } } } @@ -570,7 +572,7 @@ bool depends_on_estimate(const Expr &expr) { return dependency_checker.found_estimate; } -FunctionDAG::FunctionDAG(const vector &outputs, const MachineParams ¶ms, const Target &target) { +FunctionDAG::FunctionDAG(const vector &outputs, const Target &target) { map env = build_environment(outputs); // A mutator to apply parameter estimates to the expressions @@ -645,7 +647,7 @@ FunctionDAG::FunctionDAG(const vector &outputs, const MachineParams &p for (int s = 0; s <= (int)consumer.updates().size(); s++) { auto &stage = node.stages[s]; stage.node = &node; - stage.name = conform_name(consumer.name()); + stage.name = consumer.name(); if (s > 0) { stage.name += ".update(" + std::to_string(s - 1) + ")"; } @@ -727,7 +729,6 @@ FunctionDAG::FunctionDAG(const vector &outputs, const MachineParams &p Node::Loop l; l.var = d.var; l.accessor = stage.name + ".get_schedule().dims()[" + std::to_string(i) + "].var"; - l.python_accessor = stage.name + ".get_schedule_dim_var_name(" + std::to_string(i) + ")"; // Python index same as C++ // We already have the right variable names in the stage scope Interval in = stage_scope_with_concrete_rvar_bounds.get(l.var); @@ -817,6 +818,11 @@ FunctionDAG::FunctionDAG(const vector &outputs, const MachineParams &p check_type(op->type); } + void visit(const Reinterpret *op) override { + IRVisitor::visit(op); + check_type(op->type); + } + void check_type(Type t) { if (t.bits() > 1 && (!narrowest_type.bits() || @@ -953,7 +959,7 @@ FunctionDAG::FunctionDAG(const vector &outputs, const MachineParams &p } node.is_wrapper = node.func.is_wrapper(); - node.is_input = !node.func.has_update_definition() && node.is_wrapper && !any_incoming_edges; + node.is_input = !node.is_output && !node.func.has_update_definition() && node.is_wrapper && !any_incoming_edges; node.dimensions = node.func.dimensions(); } } @@ -1040,8 +1046,7 @@ void FunctionDAG::featurize() { } } -template -void FunctionDAG::dump_internal(OS &os) const { +void FunctionDAG::dump(std::ostream &os) const { for (const Node &n : nodes) { os << "Node: " << n.func.name() << "\n" << " Symbolic region required: \n"; @@ -1077,21 +1082,11 @@ void FunctionDAG::dump_internal(OS &os) const { os << " Load Jacobians:\n"; for (const auto &jac : e.load_jacobians) { - jac.dump(" "); + jac.dump(os, " "); } } } -void FunctionDAG::dump() const { - auto os = aslog(0); - dump_internal(os); -} - -std::ostream &FunctionDAG::dump(std::ostream &os) const { - dump_internal(os); - return os; -} - } // namespace Autoscheduler } // namespace Internal } // namespace Halide diff --git a/src/autoschedulers/adams2019/FunctionDAG.h b/src/autoschedulers/adams2019/FunctionDAG.h index b09918940517..75c75c3c8b07 100644 --- a/src/autoschedulers/adams2019/FunctionDAG.h +++ b/src/autoschedulers/adams2019/FunctionDAG.h @@ -27,6 +27,8 @@ using std::string; using std::unique_ptr; using std::vector; +struct Adams2019Params; + // First we have various utility classes. // An optional rational type used when analyzing memory dependencies. @@ -205,7 +207,7 @@ class LoadJacobian { return result; } - void dump(const char *prefix) const; + void dump(std::ostream &os, const char *prefix) const; }; // Classes to represent a concrete set of bounds for a Func. A Span is @@ -438,7 +440,6 @@ struct FunctionDAG { // from its owner Func. Used for printing source code // equivalent to a computed schedule. string accessor; - string python_accessor; // Same as above for Python schedules }; // Get the loop nest shape as a function of the region computed @@ -564,18 +565,14 @@ struct FunctionDAG { // Create the function DAG, and do all the dependency and cost // analysis. This is done once up-front before the tree search. - FunctionDAG(const vector &outputs, const MachineParams ¶ms, const Target &target); + FunctionDAG(const vector &outputs, const Target &target); - void dump() const; - std::ostream &dump(std::ostream &os) const; + void dump(std::ostream &os) const; private: // Compute the featurization for the entire DAG void featurize(); - template - void dump_internal(OS &os) const; - public: // This class uses a lot of internal pointers, so we'll make it uncopyable/unmovable. FunctionDAG(const FunctionDAG &other) = delete; diff --git a/src/autoschedulers/adams2019/LoopNest.cpp b/src/autoschedulers/adams2019/LoopNest.cpp index 309903f2c416..d6cc7e6058e1 100644 --- a/src/autoschedulers/adams2019/LoopNest.cpp +++ b/src/autoschedulers/adams2019/LoopNest.cpp @@ -13,20 +13,6 @@ namespace Autoscheduler { // registers. const int kUnrollLimit = 12; -// Get the HL_NO_SUBTILING environment variable. Purpose described above. -bool get_may_subtile() { - string no_subtiling_str = get_env_variable("HL_NO_SUBTILING"); - if (no_subtiling_str == "1") { - return false; - } else { - return true; - } -} -bool may_subtile() { - static bool b = get_may_subtile(); - return b; -} - // Given a multi-dimensional box of dimensionality d, generate a list // of candidate tile sizes for it, logarithmically spacing the sizes // using the given factor. If 'allow_splits' is false, every dimension @@ -227,7 +213,7 @@ void LoopNest::get_sites(StageMap &sites, // Do a recursive walk over the loop nest computing features to feed the cost model. void LoopNest::compute_features(const FunctionDAG &dag, - const MachineParams ¶ms, + const Adams2019Params ¶ms, const StageMap &sites, int64_t instances, int64_t parallelism, @@ -729,7 +715,6 @@ void LoopNest::compute_features(const FunctionDAG &dag, int64_t footprint = e->producer->bytes_per_point; int64_t compute_footprint = footprint; int64_t store_footprint = footprint; - int64_t task_footprint = footprint; int64_t line_footprint = 1; int64_t compute_line_footprint = 1; int64_t store_line_footprint = 1; @@ -890,7 +875,6 @@ void LoopNest::compute_features(const FunctionDAG &dag, footprint *= extent; compute_footprint *= compute_extent; store_footprint *= store_extent; - task_footprint *= task_extent; bool dense = ((e->producer->is_input && i == 0) || (site.produce != nullptr && i == site.produce->vector_dim)); @@ -1093,79 +1077,24 @@ const Bound &LoopNest::get_bounds(const FunctionDAG::Node *f) const { } // Recursively print a loop nest representation to stderr -void LoopNest::dump(string prefix, const LoopNest *parent) const { +void LoopNest::dump(std::ostream &os, string prefix, const LoopNest *parent) const { if (!is_root()) { // Non-root nodes always have parents. internal_assert(parent != nullptr); - aslog(0) << prefix << node->func.name(); + os << prefix << node->func.name(); prefix += " "; for (size_t i = 0; i < size.size(); i++) { - aslog(0) << " " << size[i]; - // The vectorized loop gets a 'v' suffix - if (innermost && i == (size_t)vectorized_loop_index) { - aslog(0) << "v"; - } - // Loops that have a known constant size get a - // 'c'. Useful for knowing what we can unroll. - if (parent->get_bounds(node)->loops(stage->index, i).constant_extent()) { - aslog(0) << "c"; - } - } - - // Uncomment when debugging the representative loop bounds selected. - /* - const auto &bounds = get_bounds(node); - for (size_t i = 0; i < size.size(); i++) { - const auto &p = bounds->loops(stage->index, i); - aslog(0) << " [" << p.min() << ", " << p.max() << "]"; - } - */ - - aslog(0) << " (" << vectorized_loop_index << ", " << vector_dim << ")"; - } - - if (tileable) { - aslog(0) << " t"; - } - if (innermost) { - aslog(0) << " *\n"; - } else if (parallel) { - aslog(0) << " p\n"; - } else { - aslog(0) << "\n"; - } - for (const auto *p : store_at) { - aslog(0) << prefix << "realize: " << p->func.name() << "\n"; - } - for (size_t i = children.size(); i > 0; i--) { - children[i - 1]->dump(prefix, this); - } - for (auto it = inlined.begin(); it != inlined.end(); it++) { - aslog(0) << prefix << "inlined: " << it.key()->func.name() << " " << it.value() << "\n"; - } -} - -string LoopNest::dump(string prefix, const LoopNest *parent, bool dummy) const { - string result; - - if (!is_root()) { - // Non-root nodes always have parents. - internal_assert(parent != nullptr); - result += prefix; - prefix += ">"; - - for (size_t i = 0; i < size.size(); i++) { - result += (std::to_string(size[i]) + ","); + os << " " << size[i]; // The vectorized loop gets a 'v' suffix if (innermost && i == (size_t)vectorized_loop_index) { - result += "v,"; + os << "v"; } // Loops that have a known constant size get a // 'c'. Useful for knowing what we can unroll. if (parent->get_bounds(node)->loops(stage->index, i).constant_extent()) { - result += "c,"; + os << "c"; } } @@ -1174,33 +1103,32 @@ string LoopNest::dump(string prefix, const LoopNest *parent, bool dummy) const { const auto &bounds = get_bounds(node); for (size_t i = 0; i < size.size(); i++) { const auto &p = bounds->loops(stage->index, i); - aslog(0) << " [" << p.min() << ", " << p.max() << "]"; + os << " [" << p.first << ", " << p.second << "]"; } */ - result += std::to_string(vectorized_loop_index) + "," + std::to_string(vector_dim) + ","; + os << " (" << vectorized_loop_index << ", " << vector_dim << ")"; } if (tileable) { - result += "t,"; + os << " t"; } if (innermost) { - result += "i,\n"; + os << " *\n"; } else if (parallel) { - result += "p,\n"; + os << " p\n"; } else { - result += "\n"; + os << "\n"; } for (const auto *p : store_at) { - //aslog(0) << prefix << "realize: " << p->func.name() << "\n"; + os << prefix << "realize: " << p->func.name() << "\n"; } for (size_t i = children.size(); i > 0; i--) { - result += children[i - 1]->dump(prefix, this, true); + children[i - 1]->dump(os, prefix, this); } for (auto it = inlined.begin(); it != inlined.end(); it++) { - //aslog(0) << prefix << "inlined: " << it.key()->func.name() << " " << it.value() << "\n"; + os << prefix << "inlined: " << it.key()->func.name() << " " << it.value() << "\n"; } - return result; } // Does this loop nest access the given Func @@ -1302,7 +1230,7 @@ void LoopNest::inline_func(const FunctionDAG::Node *f) { // Inline it into the children for (auto &child : children) { if (child->calls(f)) { - std::unique_ptr new_child{new LoopNest}; + auto new_child = std::make_unique(); new_child->copy_from(*child); new_child->inline_func(f); child = new_child.release(); @@ -1327,10 +1255,11 @@ void LoopNest::inline_func(const FunctionDAG::Node *f) { } // Compute a Func at this site. -void LoopNest::compute_here(const FunctionDAG::Node *f, bool tileable, int v) { +void LoopNest::compute_here(const FunctionDAG::Node *f, bool tileable, int v, const Adams2019Params ¶ms) { const auto &bounds = get_bounds(f); - if (!may_subtile()) { + const bool may_subtile = (params.disable_subtiling != 0); + if (!may_subtile) { // If we are restricting ourselves to the Mullapudi et al // scheduling space, then once something is computed here // we may not subtile this loop. @@ -1343,7 +1272,7 @@ void LoopNest::compute_here(const FunctionDAG::Node *f, bool tileable, int v) { node->stage = &f->stages[s]; node->innermost = true; node->vectorized_loop_index = -1; - node->tileable = tileable && (is_root() || may_subtile()); + node->tileable = tileable && (is_root() || may_subtile); // Set up a bound for the inside of the // loop. computed/required is still the full region, but // the loop nest will be a single representative point. @@ -1351,13 +1280,11 @@ void LoopNest::compute_here(const FunctionDAG::Node *f, bool tileable, int v) { size_t loop_dim = f->stages[s].loop.size(); node->size.resize(loop_dim); - int64_t total_extent = 1; int64_t vector_size = 1; for (size_t i = 0; i < loop_dim; i++) { const auto &l = bounds->loops(s, i); // Initialize the loop nest node->size[i] = l.extent(); - total_extent *= node->size[i]; // Use the first loop iteration to represent the inner // loop. We'll shift it to a later one once we decide @@ -1415,21 +1342,23 @@ void LoopNest::compute_here(const FunctionDAG::Node *f, bool tileable, int v) { } // Parallelize this loop according to the given tiling. -IntrusivePtr LoopNest::parallelize_in_tiles(const MachineParams ¶ms, +IntrusivePtr LoopNest::parallelize_in_tiles(const Adams2019Params ¶ms, const vector &tiling, const LoopNest *parent) const { + const bool may_subtile = (params.disable_subtiling != 0); + // Split this loop and move factors to the inner loop LoopNest *inner = new LoopNest, *outer = new LoopNest; inner->node = outer->node = node; inner->stage = outer->stage = stage; - inner->tileable = outer->tileable = tileable && may_subtile(); + inner->tileable = outer->tileable = tileable && may_subtile; inner->vector_dim = outer->vector_dim = vector_dim; inner->vectorized_loop_index = outer->vectorized_loop_index = vectorized_loop_index; outer->size = size; outer->innermost = false; outer->parallel = true; - outer->tileable = may_subtile(); + outer->tileable = may_subtile; // First make an inner loop representing a 1x1x1... tile inner->size.resize(size.size(), 1); @@ -1483,9 +1412,11 @@ IntrusivePtr LoopNest::parallelize_in_tiles(const MachineParams // this loop nest. vector> LoopNest::compute_in_tiles(const FunctionDAG::Node *f, const LoopNest *parent, - const MachineParams ¶ms, + const Adams2019Params ¶ms, int v, bool in_realization) const { + const bool may_subtile = (params.disable_subtiling != 0); + internal_assert(f); vector> result; @@ -1540,9 +1471,9 @@ vector> LoopNest::compute_in_tiles(const FunctionDA vector_dim == -1 || size[vector_dim] == 1)) { - std::unique_ptr r{new LoopNest}; + auto r = std::make_unique(); r->copy_from(*this); - r->compute_here(f, true, v); + r->compute_here(f, true, v, params); if (!in_realization) { r->store_at.insert(f); } else { @@ -1564,7 +1495,7 @@ vector> LoopNest::compute_in_tiles(const FunctionDA auto tilings = generate_tilings(size, (int)(size.size() - 1), 2, !in_realization); if (tilings.size() > 10000) { - aslog(0) << "Warning: lots of tilings: " << tilings.size() << "\n"; + aslog(1) << "Warning: lots of tilings: " << tilings.size() << "\n"; } for (auto t : tilings) { @@ -1595,7 +1526,7 @@ vector> LoopNest::compute_in_tiles(const FunctionDA LoopNest *inner = new LoopNest, *outer = new LoopNest; inner->node = outer->node = node; inner->stage = outer->stage = stage; - inner->tileable = outer->tileable = tileable && may_subtile(); + inner->tileable = outer->tileable = tileable && may_subtile; inner->vector_dim = outer->vector_dim = vector_dim; inner->vectorized_loop_index = outer->vectorized_loop_index = vectorized_loop_index; outer->size = size; @@ -1665,14 +1596,14 @@ vector> LoopNest::compute_in_tiles(const FunctionDA } // Site the computation inside the outer loop - outer->compute_here(f, true, v); + outer->compute_here(f, true, v, params); outer->tileable &= !in_realization; result.emplace_back(outer); } } if (child >= 0 && !called_by_multiple_children && !in_realization && - (may_subtile() || is_root())) { + (may_subtile || is_root())) { // Push the Func further inwards in the loop nest // See if it's appropriate to slide over this loop Can't @@ -1740,7 +1671,6 @@ void LoopNest::apply(LoopLevel here, if (c->stage->index == 0) { auto &state = state_map.get(c->stage); state->schedule_source << "\n .compute_root()"; - state->python_schedule_source << " \\\n .compute_root()"; // TODO: Omitting logic for printing store_root() assumes everything store_root is also compute root } } @@ -1765,7 +1695,6 @@ void LoopNest::apply(LoopLevel here, fv.var = VarOrRVar(l.var, !l.pure); fv.orig = fv.var; fv.accessor = l.accessor; - fv.python_accessor = l.python_accessor; const auto &p = parent_bounds->loops(stage->index, i); fv.extent = p.extent(); fv.constant_extent = p.constant_extent(); @@ -1806,7 +1735,6 @@ void LoopNest::apply(LoopLevel here, // or stack as it likes. Func(node->func).store_in(MemoryType::Stack); state.schedule_source << "\n .store_in(MemoryType::Stack)"; - state.python_schedule_source << " \\\n .store_in(hl.MemoryType.Stack)"; } } @@ -1840,9 +1768,7 @@ void LoopNest::apply(LoopLevel here, internal_assert(v.innermost_pure_dim && v.exists) << v.var.name() << "\n"; // Is the result of a split state.schedule_source - << "\n .vectorize(" << conform_name(v.var.name()) << ")"; - state.python_schedule_source - << " \\\n .vectorize(" << conform_name(v.var.name()) << ")"; + << "\n .vectorize(" << v.var.name() << ")"; s.vectorize(v.var); } } else { @@ -1893,9 +1819,9 @@ void LoopNest::apply(LoopLevel here, parent.exists = false; parent.extent = 1; } else { - VarOrRVar inner(Var(conform_name(parent.var.name() + "i"))); + VarOrRVar inner(Var(parent.var.name() + "i")); if (parent.var.is_rvar) { - inner = RVar(conform_name(parent.var.name() + "i", "r")); + inner = RVar(parent.var.name() + "i"); } auto tail_strategy = pure_var_tail_strategy; @@ -1912,24 +1838,16 @@ void LoopNest::apply(LoopLevel here, s.split(parent.var, parent.var, inner, (int)factor, tail_strategy); state.schedule_source << "\n .split(" - << conform_name(parent.var.name()) << ", " - << conform_name(parent.var.name()) << ", " + << parent.var.name() << ", " + << parent.var.name() << ", " << inner.name() << ", " << factor << ", " << "TailStrategy::" << tail_strategy << ")"; - state.python_schedule_source - << " \\\n .split(" - << conform_name(parent.var.name()) << ", " - << conform_name(parent.var.name()) << ", " - << inner.name() << ", " - << factor << ", " - << "hl.TailStrategy." << tail_strategy << ")"; v = parent; parent.extent = size[parent.index]; v.constant_extent = (tail_strategy != TailStrategy::GuardWithIf); v.var = inner; v.accessor.clear(); - v.python_accessor.clear(); v.extent = factor; v.parallel = false; v.outermost = false; @@ -1960,8 +1878,7 @@ void LoopNest::apply(LoopLevel here, for (size_t i = 0; i < symbolic_loop.size(); i++) { if (state.vars[i].pure && state.vars[i].exists && state.vars[i].extent > 1) { s.unroll(state.vars[i].var); - state.schedule_source << "\n .unroll(" << conform_name(state.vars[i].var.name()) << ")"; - state.python_schedule_source << " \\\n .unroll(" << conform_name(state.vars[i].var.name()) << ")"; + state.schedule_source << "\n .unroll(" << state.vars[i].var.name() << ")"; } } } @@ -2000,7 +1917,7 @@ void LoopNest::apply(LoopLevel here, if (here.is_root()) { loop_level = "_root()"; } else { - loop_level = "_at(" + conform_name(here.func()) + ", " + conform_name(here.var().name()) + ")"; + loop_level = "_at(" + here.func() + ", " + here.var().name() + ")"; } for (const auto &c : children) { if (c->node != node) { @@ -2010,7 +1927,6 @@ void LoopNest::apply(LoopLevel here, if (c->node != node && c->stage->index == 0) { auto &state = *(state_map.get(c->stage)); state.schedule_source << "\n .compute" << loop_level; - state.python_schedule_source << " \\\n .compute" << loop_level; } } for (const auto *f : store_at) { @@ -2024,7 +1940,6 @@ void LoopNest::apply(LoopLevel here, if (!computed_here) { auto &state = *(state_map.get(&(f->stages[0]))); state.schedule_source << "\n .store" << loop_level; - state.python_schedule_source << " \\\n .store" << loop_level; } } } diff --git a/src/autoschedulers/adams2019/LoopNest.h b/src/autoschedulers/adams2019/LoopNest.h index 9dba8783210f..afc707f9602a 100644 --- a/src/autoschedulers/adams2019/LoopNest.h +++ b/src/autoschedulers/adams2019/LoopNest.h @@ -23,8 +23,6 @@ using NodeMap = PerfectHashMap; template using StageMap = PerfectHashMap; -bool may_subtile(); - // Given a multi-dimensional box of dimensionality d, generate a list // of candidate tile sizes for it, logarithmically spacing the sizes // using the given factor. If 'allow_splits' is false, every dimension @@ -129,7 +127,7 @@ struct LoopNest { // Do a recursive walk over the loop nest computing features to feed the cost model. void compute_features(const FunctionDAG &dag, - const MachineParams ¶ms, + const Adams2019Params ¶ms, const StageMap &sites, int64_t instances, int64_t parallelism, @@ -157,10 +155,7 @@ struct LoopNest { const Bound &get_bounds(const FunctionDAG::Node *f) const; // Recursively print a loop nest representation to stderr - void dump(string prefix, const LoopNest *parent) const; - - // Recursively collect dump info above into a feature vector - string dump(string prefix, const LoopNest *parent, bool dummy) const; + void dump(std::ostream &os, string prefix, const LoopNest *parent) const; // Does this loop nest access the given Func bool calls(const FunctionDAG::Node *f) const; @@ -189,10 +184,10 @@ struct LoopNest { void inline_func(const FunctionDAG::Node *f); // Compute a Func at this site. - void compute_here(const FunctionDAG::Node *f, bool tileable, int v); + void compute_here(const FunctionDAG::Node *f, bool tileable, int v, const Adams2019Params ¶ms); // Parallelize this loop according to the given tiling. - IntrusivePtr parallelize_in_tiles(const MachineParams ¶ms, + IntrusivePtr parallelize_in_tiles(const Adams2019Params ¶ms, const vector &tiling, const LoopNest *parent) const; @@ -200,7 +195,7 @@ struct LoopNest { // this loop nest. std::vector> compute_in_tiles(const FunctionDAG::Node *f, const LoopNest *parent, - const MachineParams ¶ms, + const Adams2019Params ¶ms, int v, bool in_realization) const; @@ -230,7 +225,6 @@ struct LoopNest { // Source code to access this Var/RVar. Used for printing // valid Halide source for this schedule. string accessor; - string python_accessor; // same as above for Python schedules // Our estimate of the extent of this var. This is exact // when constant_extent flag is true. @@ -256,7 +250,6 @@ struct LoopNest { std::vector vars; std::ostringstream schedule_source; - std::ostringstream python_schedule_source; }; // Apply the schedule represented by this loop nest to a Halide pipeline. diff --git a/src/autoschedulers/adams2019/Makefile b/src/autoschedulers/adams2019/Makefile index 050bce258ebd..25498da3ac5d 100644 --- a/src/autoschedulers/adams2019/Makefile +++ b/src/autoschedulers/adams2019/Makefile @@ -52,13 +52,18 @@ $(BIN)/auto_schedule_runtime.a: $(BIN)/cost_model.generator $(BIN)/cost_model/%.a: $(BIN)/cost_model.generator @mkdir -p $(@D) - $^ -g $* -o $(BIN)/cost_model -f $* target=$(HL_TARGET)-no_runtime auto_schedule=false -e stmt,static_library,h,assembly + $^ -g $* -o $(BIN)/cost_model -f $* target=$(HL_TARGET)-no_runtime -e stmt,static_library,h,assembly # It's important to use dynamic lookups for undefined symbols here: all of libHalide # is expected to be present (in the loading binary), so we explicitly make the symbols # undefined rather than dependent on libHalide.so. -$(BIN)/libautoschedule_adams2019.$(SHARED_EXT): $(SRC)/AutoSchedule.cpp \ - $(SRC)/ASLog.cpp \ +# +# Also, be sure *not* to include libHalide in the link steps here; that can cause misbehavior +# on OSX systems in certain situations -- note that $(LIB_HALIDE) is an order-only dep, +# to ensure that (eg) Halide.h is built before this. +$(BIN)/libautoschedule_adams2019.$(PLUGIN_EXT): \ + $(COMMON_DIR)/ASLog.cpp \ + $(SRC)/AutoSchedule.cpp \ $(SRC)/Cache.h \ $(SRC)/Cache.cpp \ $(SRC)/DefaultCostModel.h \ @@ -77,13 +82,13 @@ $(BIN)/libautoschedule_adams2019.$(SHARED_EXT): $(SRC)/AutoSchedule.cpp \ $(SRC)/PerfectHashMap.h \ $(AUTOSCHED_WEIGHT_OBJECTS) \ $(AUTOSCHED_COST_MODEL_LIBS) \ - $(GENERATOR_DEPS) \ - $(BIN)/auto_schedule_runtime.a + $(BIN)/auto_schedule_runtime.a \ + | $(LIB_HALIDE) @mkdir -p $(@D) $(CXX) -shared $(USE_EXPORT_DYNAMIC) -fPIC -fvisibility=hidden -fvisibility-inlines-hidden $(CXXFLAGS) $(OPTIMIZE) -I $(BIN)/cost_model $(filter-out %.h $(LIBHALIDE_LDFLAGS),$^) -o $@ $(HALIDE_SYSTEM_LIBS) $(HALIDE_RPATH_FOR_LIB) $(BIN)/retrain_cost_model: $(SRC)/retrain_cost_model.cpp \ - $(SRC)/ASLog.cpp \ + $(COMMON_DIR)/ASLog.cpp \ $(SRC)/DefaultCostModel.h \ $(SRC)/DefaultCostModel.cpp \ $(SRC)/Weights.h \ @@ -107,11 +112,6 @@ $(BIN)/weightsdir_to_weightsfile: $(SRC)/weightsdir_to_weightsfile.cpp $(SRC)/We @mkdir -p $(@D) $(CXX) $(CXXFLAGS) $^ $(OPTIMIZE) -o $@ -# This is the value that machine_params defaults to if no custom value is specified; -# see MachineParams::generic() -HL_MACHINE_PARAMS ?= 32,25165824,160 - - # A sample generator to autoschedule. Note that if it statically links # to libHalide, then it must be build with $(USE_EXPORT_DYNAMIC), or the # autoscheduler can't find the libHalide symbols that it needs. @@ -120,22 +120,25 @@ $(GENERATOR_BIN)/demo.generator: $(SRC)/demo_generator.cpp $(GENERATOR_DEPS) $(CXX) $(CXXFLAGS) $(USE_EXPORT_DYNAMIC) -g $(filter %.cpp,$^) -o $@ $(LIBHALIDE_LDFLAGS) # To use the autoscheduler, set a few environment variables and use the -p flag to the generator to load the autoscheduler as a plugin -$(BIN)/%/demo.a: $(GENERATOR_BIN)/demo.generator $(BIN)/libautoschedule_adams2019.$(SHARED_EXT) +$(BIN)/%/demo.a: $(GENERATOR_BIN)/demo.generator $(BIN)/libautoschedule_adams2019.$(PLUGIN_EXT) @mkdir -p $(@D) - HL_WEIGHTS_DIR=$(SRC)/baseline.weights \ - $(GENERATOR_BIN)/demo.generator -g demo -o $(@D) -f demo target=$* auto_schedule=true -p $(BIN)/libautoschedule_adams2019.$(SHARED_EXT) -s Adams2019 + $(GENERATOR_BIN)/demo.generator -g demo -o $(@D) -f demo target=$* \ + autoscheduler=Adams2019 \ + autoscheduler.parallelism=32 \ + autoscheduler.weights_path=$(SRC)/baseline.weights \ + -p $(BIN)/libautoschedule_adams2019.$(PLUGIN_EXT) $(BIN)/%/demo.rungen: $(BIN)/%/RunGenMain.o $(BIN)/%/demo.registration.cpp $(BIN)/%/demo.a @mkdir -p $(@D) $(CXX) $(CXXFLAGS) -I$(BIN)/$* $^ -o $@ $(HALIDE_SYSTEM_LIBS) $(IMAGE_IO_FLAGS) # demonstrates single-shot use of the autoscheduler -demo: $(BIN)/$(HL_TARGET)/demo.rungen $(BIN)/libautoschedule_adams2019.$(SHARED_EXT) +demo: $(BIN)/$(HL_TARGET)/demo.rungen $(BIN)/libautoschedule_adams2019.$(PLUGIN_EXT) $< --benchmarks=all --benchmark_min_time=1 --estimate_all # demonstrates an autotuning loop # (using $(BIN) and $(SRC) here seems overkill, but makes copy-n-paste elsewhere easier) -autotune: $(GENERATOR_BIN)/demo.generator $(BIN)/featurization_to_sample $(BIN)/get_host_target $(BIN)/retrain_cost_model $(BIN)/libautoschedule_adams2019.$(SHARED_EXT) $(SRC)/autotune_loop.sh +autotune: $(GENERATOR_BIN)/demo.generator $(BIN)/featurization_to_sample $(BIN)/get_host_target $(BIN)/retrain_cost_model $(BIN)/libautoschedule_adams2019.$(PLUGIN_EXT) $(SRC)/autotune_loop.sh @mkdir -p $(@D) bash $(SRC)/autotune_loop.sh \ $(GENERATOR_BIN)/demo.generator \ @@ -150,14 +153,14 @@ $(BIN)/test_perfect_hash_map: $(SRC)/test_perfect_hash_map.cpp $(SRC)/PerfectHas @mkdir -p $(@D) $(CXX) $(CXXFLAGS) $< -o $@ -$(BIN)/test_function_dag: $(SRC)/test_function_dag.cpp $(SRC)/FunctionDAG.h $(SRC)/FunctionDAG.cpp $(SRC)/ASLog.h $(SRC)/ASLog.cpp +$(BIN)/test_function_dag: $(SRC)/test_function_dag.cpp $(SRC)/FunctionDAG.h $(SRC)/FunctionDAG.cpp $(COMMON_DIR)/ASLog.h $(COMMON_DIR)/ASLog.cpp @mkdir -p $(@D) $(CXX) $(CXXFLAGS) $(USE_EXPORT_DYNAMIC) $(filter-out %.h,$^) -o $@ $(LIBHALIDE_LDFLAGS) $(HALIDE_SYSTEM_LIBS) # Simple jit-based test -$(BIN)/%/test: $(SRC)/test.cpp $(BIN)/libautoschedule_adams2019.$(SHARED_EXT) +$(BIN)/%/test: $(SRC)/test.cpp $(BIN)/libautoschedule_adams2019.$(PLUGIN_EXT) @mkdir -p $(@D) - $(CXX) $(CXXFLAGS) $(USE_EXPORT_DYNAMIC) $^ -o $@ $(LIBHALIDE_LDFLAGS) $(HALIDE_SYSTEM_LIBS) + $(CXX) $(CXXFLAGS) $(USE_EXPORT_DYNAMIC) $< -o $@ $(LIBHALIDE_LDFLAGS) $(HALIDE_SYSTEM_LIBS) test_perfect_hash_map: $(BIN)/test_perfect_hash_map $^ @@ -166,7 +169,7 @@ test_function_dag: $(BIN)/test_function_dag $^ run_test: $(BIN)/$(HL_TARGET)/test - HL_WEIGHTS_DIR=$(SRC)/baseline.weights LD_LIBRARY_PATH=$(BIN):$(LD_LIBRARY_PATH) $< $(BIN)/libautoschedule_adams2019.$(SHARED_EXT) + LD_LIBRARY_PATH=$(BIN):$(LD_LIBRARY_PATH) $< $(BIN)/libautoschedule_adams2019.$(PLUGIN_EXT) $(SRC)/baseline.weights .PHONY: test clean @@ -181,7 +184,7 @@ build: $(BIN)/$(HL_TARGET)/test \ $(BIN)/featurization_to_sample \ $(BIN)/get_host_target \ $(BIN)/retrain_cost_model \ - $(BIN)/libautoschedule_adams2019.$(SHARED_EXT) + $(BIN)/libautoschedule_adams2019.$(PLUGIN_EXT) test: run_test test_perfect_hash_map test_function_dag demo test_included_schedule_file autotune @@ -204,10 +207,13 @@ $(GENERATOR_BIN)/included_schedule_file_none.generator: $(SRC)/included_schedule # This is the target you build to (re)generate the schedule file. # (Note that we only need the schedule output, so we pass `-e schedule` to # the Generator so that it can skip producing other outputs.) -$(BIN)/%/included_schedule_file.schedule.h: $(GENERATOR_BIN)/included_schedule_file_none.generator $(BIN)/libautoschedule_adams2019.$(SHARED_EXT) +$(BIN)/%/included_schedule_file.schedule.h: $(GENERATOR_BIN)/included_schedule_file_none.generator $(BIN)/libautoschedule_adams2019.$(PLUGIN_EXT) @mkdir -p $(@D) - HL_WEIGHTS_DIR=$(SRC)/baseline.weights \ - $< -g included_schedule_file -o $(@D) -f included_schedule_file target=$* auto_schedule=true -p $(BIN)/libautoschedule_adams2019.$(SHARED_EXT) -s Adams2019 -e schedule + $< -g included_schedule_file -o $(@D) -f included_schedule_file target=$* \ + autoscheduler=Adams2019 \ + autoscheduler.parallelism=32 \ + autoscheduler.weights_path=$(SRC)/baseline.weights \ + -p $(BIN)/libautoschedule_adams2019.$(PLUGIN_EXT) -e schedule # Note that this depends on included_schedule_file.schedule.h rather than $(BIN)/%/included_schedule_file.schedule.h -- # the former should be generated by something like diff --git a/src/autoschedulers/adams2019/State.cpp b/src/autoschedulers/adams2019/State.cpp index 2fde5974e7f9..6cd994a0625c 100644 --- a/src/autoschedulers/adams2019/State.cpp +++ b/src/autoschedulers/adams2019/State.cpp @@ -14,7 +14,7 @@ uint64_t State::structural_hash(int depth) const { return h; } -void State::compute_featurization(const FunctionDAG &dag, const MachineParams ¶ms, +void State::compute_featurization(const FunctionDAG &dag, const Adams2019Params ¶ms, StageMap *features, const CachingOptions &cache_options) { StageMap sites; sites.make_large(dag.nodes[0].stages[0].max_id); @@ -54,10 +54,10 @@ void State::compute_featurization(const FunctionDAG &dag, const MachineParams &p l = consumer_site.compute; } if (!l) { - if (aslog::aslog_level() > 0) { - dump(); - } - internal_error << e->producer->func.name() << " -> " << e->consumer->name << "\n"; + std::ostringstream err; + dump(err); + err << e->producer->func.name() << " -> " << e->consumer->name << "\n"; + internal_error << err.str(); } if (loop) { loop = deepest_common_ancestor(parent, l, loop); @@ -93,45 +93,17 @@ void State::compute_featurization(const FunctionDAG &dag, const MachineParams &p } } -void State::save_featurization(const FunctionDAG &dag, const MachineParams ¶ms, - const CachingOptions &cache_options, std::ostream &out, std::ostream &index_out) { +void State::save_featurization(const FunctionDAG &dag, const Adams2019Params ¶ms, + const CachingOptions &cache_options, std::ostream &out) { StageMap features; compute_featurization(dag, params, &features, cache_options); - index_out << "{ \"feature_stage_index\":\n[\n"; - int offset = 0; - - int last_node_id = 0, last_stage_id = 0; - for (const auto &n : dag.nodes) { - if (n.is_input) { - continue; - } - - for (size_t stage_idx = n.stages.size(); stage_idx > 0; stage_idx--) { - const auto &s = n.stages[stage_idx - 1]; - last_node_id = n.id; - last_stage_id = s.id; - } - } - for (const auto &n : dag.nodes) { if (n.is_input) { continue; } - for (size_t stage_idx = n.stages.size(); stage_idx > 0; stage_idx--) { const auto &s = n.stages[stage_idx - 1]; - index_out << " {\n"; - index_out << " \"node_id\": " << n.id << ",\n"; - index_out << " \"node_name\": \"" << n.func.name() << "\",\n"; - index_out << " \"stage_id\": " << s.id << ",\n"; - index_out << " \"stage_name\": \"" << s.name << "\",\n"; - index_out << " \"stage_offset\": " << offset << "\n"; - if (last_node_id == n.id && last_stage_id == s.id) { - index_out << " }\n"; - } else { - index_out << " },\n"; - } const size_t num_schedule_features = ScheduleFeatures::num_features(); const size_t num_pipeline_features = PipelineFeatures::num_features(); const auto &sched_feat = features.get(&s); @@ -147,26 +119,24 @@ void State::save_featurization(const FunctionDAG &dag, const MachineParams ¶ } out.write((const char *)buf, sizeof(buf)); - offset += (num_schedule_features + num_pipeline_features); } } - index_out << "]}\n"; } -bool State::calculate_cost(const FunctionDAG &dag, const MachineParams ¶ms, +bool State::calculate_cost(const FunctionDAG &dag, const Adams2019Params ¶ms, CostModel *cost_model, const CachingOptions &cache_options, - int64_t memory_limit, bool verbose) { + int verbosity) { StageMap features; compute_featurization(dag, params, &features, cache_options); cost = 0.0f; - if (verbose) { + if (verbosity <= aslog::aslog_level()) { for (auto it = features.begin(); it != features.end(); it++) { const auto &stage = *(it.key()); const auto &feat = it.value(); - aslog(0) << "Schedule features for " << stage.stage.name() << "\n"; - feat.dump(); + aslog(verbosity) << "Schedule features for " << stage.stage.name() << "\n"; + feat.dump(aslog(verbosity).get_ostream()); } } @@ -190,7 +160,7 @@ bool State::calculate_cost(const FunctionDAG &dag, const MachineParams ¶ms, } // Apply the hard limit on memory use - if (memory_limit >= 0) { + if (params.memory_limit >= 0) { int64_t mem_used = (int64_t)features.begin().value().working_set_at_root; for (auto it = features.begin(); it != features.end(); it++) { if (it.key()->node->is_output || @@ -199,7 +169,7 @@ bool State::calculate_cost(const FunctionDAG &dag, const MachineParams ¶ms, mem_used -= it.value().bytes_at_production; } } - if (mem_used > memory_limit) { + if (mem_used > params.memory_limit) { cost = 1e50; return false; } @@ -228,13 +198,10 @@ IntrusivePtr State::make_child() const { return s; } -#include -using namespace std; // Generate the successor states to this state void State::generate_children(const FunctionDAG &dag, - const MachineParams ¶ms, + const Adams2019Params ¶ms, CostModel *cost_model, - int64_t memory_limit, std::function &&)> &accept_child, Cache *cache) const { @@ -247,7 +214,7 @@ void State::generate_children(const FunctionDAG &dag, int next_node = num_decisions_made / 2; int phase = num_decisions_made % 2; - if (!may_subtile()) { + if (params.disable_subtiling) { // When emulating the older search space, we do all // parallelizing last, so that it is independent of the // tiling decisions. @@ -267,7 +234,7 @@ void State::generate_children(const FunctionDAG &dag, // We don't need to schedule nodes that represent inputs, // and there are no other decisions to be made about them // at this time. - // aslog(0) << "Skipping over scheduling input node: " << node->func.name() << "\n"; + // aslog(1) << "Skipping over scheduling input node: " << node->func.name() << "\n"; auto child = make_child(); child->num_decisions_made++; accept_child(std::move(child)); @@ -275,17 +242,19 @@ void State::generate_children(const FunctionDAG &dag, } if (!node->outgoing_edges.empty() && !root->calls(node)) { - aslog(0) << "In state:\n"; - dump(); - aslog(0) << node->func.name() << " is consumed by:\n"; + std::ostringstream err; + err << "In state:\n"; + dump(err); + err << node->func.name() << " is consumed by:\n"; for (const auto *e : node->outgoing_edges) { - aslog(0) << e->consumer->name << "\n"; - aslog(0) << "Which in turn consumes:\n"; + err << e->consumer->name << "\n"; + err << "Which in turn consumes:\n"; for (const auto *e2 : e->consumer->incoming_edges) { - aslog(0) << " " << e2->producer->func.name() << "\n"; + err << " " << e2->producer->func.name() << "\n"; } } - internal_error << "Pipeline so far doesn't use next Func: " << node->func.name() << "\n"; + err << "Pipeline so far doesn't use next Func: " << node->func.name() << "\n"; + internal_error << err.str(); } int num_children = 0; @@ -301,7 +270,7 @@ void State::generate_children(const FunctionDAG &dag, new_root->inline_func(node); child->root = new_root; child->num_decisions_made++; - if (child->calculate_cost(dag, params, cost_model, cache->options, memory_limit)) { + if (child->calculate_cost(dag, params, cost_model, cache->options)) { num_children++; accept_child(std::move(child)); } @@ -383,7 +352,7 @@ void State::generate_children(const FunctionDAG &dag, auto child = make_child(); child->root = std::move(n); child->num_decisions_made++; - if (child->calculate_cost(dag, params, cost_model, cache->options, memory_limit)) { + if (child->calculate_cost(dag, params, cost_model, cache->options)) { num_children++; accept_child(std::move(child)); } @@ -416,7 +385,7 @@ void State::generate_children(const FunctionDAG &dag, } else { internal_assert(pure_size); - if (cache->add_memoized_blocks(this, accept_child, node, num_children, dag, params, cost_model, memory_limit)) { + if (cache->add_memoized_blocks(this, accept_child, node, num_children, dag, params, cost_model)) { return; // successfully added cached states. } @@ -504,7 +473,7 @@ void State::generate_children(const FunctionDAG &dag, } for (const auto &o : options) { - if (num_children >= 1 && (o.idle_core_wastage > 1.2 || !may_subtile())) { + if (num_children >= 1 && (o.idle_core_wastage > 1.2 || params.disable_subtiling)) { // We have considered several options, and the // remaining ones leave lots of cores idle. break; @@ -515,7 +484,7 @@ void State::generate_children(const FunctionDAG &dag, new_root->copy_from(*root); for (auto &c : new_root->children) { if (c->node == node) { - if (may_subtile()) { + if (!params.disable_subtiling) { c = c->parallelize_in_tiles(params, o.tiling, new_root); } else { // We're emulating the old @@ -541,7 +510,7 @@ void State::generate_children(const FunctionDAG &dag, } child->root = new_root; child->num_decisions_made++; - if (child->calculate_cost(dag, params, cost_model, cache->options, memory_limit)) { + if (child->calculate_cost(dag, params, cost_model, cache->options)) { num_children++; accept_child(std::move(child)); // Will early return if block caching is not enabled. @@ -552,56 +521,47 @@ void State::generate_children(const FunctionDAG &dag, } if (num_children == 0) { - aslog(0) << "Warning: Found no legal way to schedule " + aslog(1) << "Warning: Found no legal way to schedule " << node->func.name() << " in the following State:\n"; - dump(); + dump(aslog(1).get_ostream()); // All our children died. Maybe other states have had // children. Carry on. } } -void State::dump() const { - aslog(0) << "State with cost " << cost << ":\n"; - root->dump("", nullptr); - aslog(0) << schedule_source; - aslog(0) << "----- Python schedule -----"; - aslog(0) << python_schedule_source; -} - -string State::dump(bool dummy) const { - return root->dump("", nullptr, dummy); +void State::dump(std::ostream &os) const { + os << "State with cost " << cost << ":\n"; + root->dump(os, "", nullptr); + os << schedule_source; } // Apply the schedule represented by this state to a Halide // Pipeline. Also generate source code for the schedule for the // user to copy-paste to freeze this schedule as permanent artifact. -void State::apply_schedule(const FunctionDAG &dag, const MachineParams ¶ms) { +void State::apply_schedule(const FunctionDAG &dag, const Adams2019Params ¶ms) { StageMap> state_map; root->apply(LoopLevel::root(), state_map, params.parallelism, 0, nullptr, nullptr); - std::ostringstream src, python_src; + std::ostringstream src; // Print handles for all the Funcs int i = (int)(dag.nodes.size() - 1); for (const auto &n : dag.nodes) { if (!n.is_input) { - src << "Func " << conform_name(n.func.name()) << " = pipeline.get_func(" << i << ");\n"; - python_src << conform_name(n.func.name()) << " = pipeline.get_func(" << i << ")\n"; + src << "Func " << n.func.name() << " = pipeline.get_func(" << i << ");\n"; } i--; } // Gather all Vars and RVars so that we can declare them in the emitted source - map vars, rvars, python_vars, python_rvars; + map vars, rvars; for (auto &p : state_map) { for (auto &v : p.second->vars) { if (v.exists) { if (v.var.is_rvar) { rvars.emplace(v.var.name(), v.accessor); - python_rvars.emplace(v.var.name(), v.python_accessor); } else { vars.emplace(v.var.name(), v.accessor); - python_vars.emplace(v.var.name(), v.python_accessor); } } } @@ -609,36 +569,18 @@ void State::apply_schedule(const FunctionDAG &dag, const MachineParams ¶ms) if (!vars.empty()) { for (const auto &p : vars) { if (p.second.empty()) { - src << "Var " << conform_name(p.first) << "(\"" << p.first << "\");\n"; + src << "Var " << p.first << "(\"" << p.first << "\");\n"; } else { - src << "Var " << conform_name(p.first) << "(" << p.second << ");\n"; + src << "Var " << p.first << "(" << p.second << ");\n"; } } } if (!rvars.empty()) { for (const auto &p : rvars) { if (p.second.empty()) { - src << "RVar " << conform_name(p.first) << "(\"" << p.first << "\");\n"; - } else { - src << "RVar " << conform_name(p.first) << "(" << p.second << ");\n"; - } - } - } - if (!python_vars.empty()) { - for (const auto &p : python_vars) { - if (p.second.empty()) { - python_src << conform_name(p.first) << " = hl.Var(\"" << p.first << "\")\n"; + src << "RVar " << p.first << "(\"" << p.first << "\");\n"; } else { - python_src << conform_name(p.first) << " = hl.Var(" << p.second << ")\n"; - } - } - } - if (!python_rvars.empty()) { - for (const auto &p : python_rvars) { - if (p.second.empty()) { - python_src << conform_name(p.first) << " = hl.RVar(\"" << p.first << "\")\n"; - } else { - python_src << conform_name(p.first) << " = hl.RVar(" << p.second << ")\n"; + src << "RVar " << p.first << "(" << p.second << ");\n"; } } } @@ -653,7 +595,6 @@ void State::apply_schedule(const FunctionDAG &dag, const MachineParams ¶ms) // Do all the reorders and pick which vars to // parallelize. vector vars; - int64_t parallel_tasks = 1; vector parallel_vars; bool any_parallel_vars = false, any_parallel_rvars = false; for (auto it = p.second->vars.rbegin(); it != p.second->vars.rend(); it++) { @@ -665,31 +606,25 @@ void State::apply_schedule(const FunctionDAG &dag, const MachineParams ¶ms) } any_parallel_rvars |= it->var.is_rvar; any_parallel_vars |= !it->var.is_rvar; - parallel_tasks *= it->extent; parallel_vars.push_back(it->var); } if (p.second->vars.size() > 1) { p.second->schedule_source << "\n .reorder("; - p.second->python_schedule_source << " \\\n .reorder("; bool first = true; for (auto &v : p.second->vars) { if (v.exists) { vars.push_back(v.var); if (!first) { p.second->schedule_source << ", "; - p.second->python_schedule_source << ", "; } else { p.second->schedule_source << "{"; - p.second->python_schedule_source << " "; } first = false; - p.second->schedule_source << conform_name(v.var.name()); - p.second->python_schedule_source << conform_name(v.var.name()); + p.second->schedule_source << v.var.name(); } } p.second->schedule_source << "})"; - p.second->python_schedule_source << " )"; stage.reorder(vars); } @@ -700,23 +635,18 @@ void State::apply_schedule(const FunctionDAG &dag, const MachineParams ¶ms) for (size_t i = 1; i < parallel_vars.size(); i++) { // Outermost, and next outermost. Preserve the inner // name to not invalidate any compute_ats. - p.second->schedule_source << "\n .fuse(" << conform_name(parallel_vars[i].name()) - << ", " << conform_name(parallel_vars[i - 1].name()) - << ", " << conform_name(parallel_vars[i].name()) << ")"; - p.second->python_schedule_source << " \\\n .fuse(" << conform_name(parallel_vars[i].name()) - << ", " << conform_name(parallel_vars[i - 1].name()) - << ", " << conform_name(parallel_vars[i].name()) << ")"; + p.second->schedule_source << "\n .fuse(" << parallel_vars[i].name() + << ", " << parallel_vars[i - 1].name() + << ", " << parallel_vars[i].name() << ")"; stage.fuse(parallel_vars[i], parallel_vars[i - 1], parallel_vars[i]); } if (!parallel_vars.empty()) { - p.second->schedule_source << "\n .parallel(" << conform_name(parallel_vars.back().name()) << ")"; - p.second->python_schedule_source << " \\\n .parallel(" << conform_name(parallel_vars.back().name()) << ")"; + p.second->schedule_source << "\n .parallel(" << parallel_vars.back().name() << ")"; stage.parallel(parallel_vars.back()); } } else { for (const auto &v : parallel_vars) { - p.second->schedule_source << "\n .parallel(" << conform_name(v.name()) << ")"; - p.second->python_schedule_source << " \\\n .parallel(" << conform_name(v.name()) << ")"; + p.second->schedule_source << "\n .parallel(" << v.name() << ")"; stage.parallel(v); } } @@ -728,19 +658,15 @@ void State::apply_schedule(const FunctionDAG &dag, const MachineParams ¶ms) std::swap(storage_vars[i], storage_vars[i - 1]); } p.second->schedule_source << "\n .reorder_storage("; - p.second->python_schedule_source << " \\\n .reorder_storage("; bool first = true; for (const auto &v : storage_vars) { if (!first) { p.second->schedule_source << ", "; - p.second->python_schedule_source << ", "; } first = false; - p.second->schedule_source << conform_name(v.name()); - p.second->python_schedule_source << conform_name(v.name()); + p.second->schedule_source << v.name(); } p.second->schedule_source << ")"; - p.second->python_schedule_source << ")"; Func(p.first->node->func).reorder_storage(storage_vars); } @@ -748,24 +674,16 @@ void State::apply_schedule(const FunctionDAG &dag, const MachineParams ¶ms) src << p.first->name << p.second->schedule_source.str() << ";\n"; - python_src << p.first->name - << p.second->python_schedule_source.str() - << "\n\n"; } // Sanitize the names of things to make them legal source code. schedule_source = src.str(); - python_schedule_source = python_src.str(); - auto sanitize = [](std::string& source) { - bool in_quotes = false; - for (auto &c : source) { - in_quotes ^= (c == '"'); - if (!in_quotes && c == '$') { - c = '_'; - } + bool in_quotes = false; + for (auto &c : schedule_source) { + in_quotes ^= (c == '"'); + if (!in_quotes && c == '$') { + c = '_'; } - }; - sanitize(schedule_source); - sanitize(python_schedule_source); + } } } // namespace Autoscheduler diff --git a/src/autoschedulers/adams2019/State.h b/src/autoschedulers/adams2019/State.h index ba831574832e..4e7a31a63006 100644 --- a/src/autoschedulers/adams2019/State.h +++ b/src/autoschedulers/adams2019/State.h @@ -34,9 +34,6 @@ struct State { // The C++ source code of the generated schedule for this State. // Computed if `apply_schedule` is called. string schedule_source; - // The Python source code of the generated schedule for this State. - // Computed if `apply_schedule` is called. - string python_schedule_source; // The number of times a cost is enqueued into the cost model, // for all states. @@ -55,23 +52,22 @@ struct State { // Compute the featurization of this state (based on `root`), // and store features in `features`. Defers to `root->compute_features()`. void compute_featurization(const FunctionDAG &dag, - const MachineParams ¶ms, + const Adams2019Params ¶ms, StageMap *features, const CachingOptions &cache_options); // Calls `compute_featurization` and prints those features to `out`. void save_featurization(const FunctionDAG &dag, - const MachineParams ¶ms, + const Adams2019Params ¶ms, const CachingOptions &cache_options, - std::ostream &out, - std::ostream &index_out); + std::ostream &out); // Performs some pruning to decide if this state is worth queuing in // the cost_model. If it is, calls `cost_model->enqueue` and returns true, // otherwise sets `cost` equal to a large value and returns false. - bool calculate_cost(const FunctionDAG &dag, const MachineParams ¶ms, + bool calculate_cost(const FunctionDAG &dag, const Adams2019Params ¶ms, CostModel *cost_model, const CachingOptions &cache_options, - int64_t memory_limit, bool verbose = false); + int verbosity = 99); // Make a child copy of this state. The loop nest is const (we // make mutated copies of it, rather than mutating it), so we can @@ -83,23 +79,19 @@ struct State { // If they are not pruned by `calculate_cost()`, // then calls `accept_child()` on them. void generate_children(const FunctionDAG &dag, - const MachineParams ¶ms, + const Adams2019Params ¶ms, CostModel *cost_model, - int64_t memory_limit, std::function &&)> &accept_child, Cache *cache) const; - // Dumps cost, the `root` LoopNest, and then `schedule_source` to `aslog(0)`. - void dump() const; - - // Dump schedule "feature vector" - string dump(bool dummy) const; + // Dumps cost, the `root` LoopNest, and then `schedule_source` to `os`. + void dump(std::ostream &os) const; // Apply the schedule represented by this state to a Halide // Pipeline. Also generate source code for the schedule for the // user to copy-paste to freeze this schedule as permanent artifact. // Also fills `schedule_source`. - void apply_schedule(const FunctionDAG &dag, const MachineParams ¶ms); + void apply_schedule(const FunctionDAG &dag, const Adams2019Params ¶ms); }; } // namespace Autoscheduler diff --git a/src/autoschedulers/adams2019/autotune_loop.sh b/src/autoschedulers/adams2019/autotune_loop.sh index f3b731b37a13..92f419c005de 100755 --- a/src/autoschedulers/adams2019/autotune_loop.sh +++ b/src/autoschedulers/adams2019/autotune_loop.sh @@ -65,14 +65,6 @@ else echo Copying starting weights from ${START_WEIGHTS_FILE} to ${WEIGHTS} fi -# We could add this unconditionally, but it's easier to wade thru -# results if we only add if needed -for F in disable_llvm_loop_opt; do - if [[ ! ${HL_TARGET} =~ .*${F}.* ]]; then - HL_TARGET="${HL_TARGET}-${F}" - fi -done - # A batch of this many samples is built in parallel, and then # benchmarked serially. BATCH_SIZE=32 @@ -87,6 +79,8 @@ if [ $(uname -s) = "Darwin" ] && ! which $TIMEOUT_CMD 2>&1 >/dev/null; then fi fi +PLUGIN_EXT=so + # Build a single featurization of the pipeline with a random schedule make_featurization() { D=${1} @@ -105,23 +99,22 @@ make_featurization() { dropout=1 # 1% chance of operating entirely greedily beam=1 fi - HL_SEED=${SEED} \ - HL_WEIGHTS_DIR=${WEIGHTS} \ - HL_RANDOM_DROPOUT=${dropout} \ - HL_BEAM_SIZE=${beam} \ - HL_MACHINE_PARAMS=32,24000000,40 \ - ${TIMEOUT_CMD} -k ${COMPILATION_TIMEOUT} ${COMPILATION_TIMEOUT} \ + ${TIMEOUT_CMD} -k ${COMPILATION_TIMEOUT} ${COMPILATION_TIMEOUT} \ ${GENERATOR} \ -g ${PIPELINE} \ -f ${FNAME} \ -o ${D} \ - -e stmt,assembly,static_library,c_header,registration,schedule,featurization,python_schedule \ + -e stmt,assembly,static_library,c_header,registration,schedule,featurization \ target=${HL_TARGET} \ - auto_schedule=true \ ${EXTRA_GENERATOR_ARGS} \ - -p ${AUTOSCHED_BIN}/libautoschedule_adams2019.so \ - -s Adams2019 \ - 2> ${D}/compile_log.txt || echo "Compilation failed or timed out for ${D}" + -p ${AUTOSCHED_BIN}/libautoschedule_adams2019.${PLUGIN_EXT} \ + autoscheduler=Adams2019 \ + autoscheduler.parallelism=32 \ + autoscheduler.beam_size=${beam} \ + autoscheduler.random_dropout=${dropout} \ + autoscheduler.random_dropout_seed=${SEED} \ + autoscheduler.weights_path=${WEIGHTS} \ + 2> ${D}/compile_log.txt || echo "Compilation failed or timed out for ${D}" # We don't need image I/O for this purpose, @@ -227,8 +220,7 @@ for ((BATCH_ID=$((FIRST+1));BATCH_ID<$((FIRST+1+NUM_BATCHES));BATCH_ID++)); do --initial_weights=${WEIGHTS} \ --weights_out=${WEIGHTS} \ --best_benchmark=${SAMPLES}/best.${PIPELINE}.benchmark.txt \ - --best_schedule=${SAMPLES}/best.${PIPELINE}.schedule.h \ - --best_python_schedule=${SAMPLES}/best_${PIPELINE}_schedule.py + --best_schedule=${SAMPLES}/best.${PIPELINE}.schedule.h done echo Batch ${BATCH_ID} took ${SECONDS} seconds to compile, benchmark, and retrain diff --git a/src/autoschedulers/adams2019/cost_model_generator.cpp b/src/autoschedulers/adams2019/cost_model_generator.cpp index 65d9dff386ec..4ab6b59c1b57 100644 --- a/src/autoschedulers/adams2019/cost_model_generator.cpp +++ b/src/autoschedulers/adams2019/cost_model_generator.cpp @@ -123,7 +123,7 @@ class CostModel : public Generator> { using Input = GeneratorInput; template using Output = GeneratorOutput; - using Generator>::auto_schedule; + using Generator>::using_autoscheduler; using Generator>::get_pipeline; // Number of pipeline stages @@ -434,35 +434,10 @@ class CostModel : public Generator> { Expr r1 = true_runtime(n) * scale; // Invert them to get relative throughput, and compute L2 loss. - // Expr delta = pow(1.0f / max(p1, 1e-10f) - 1.0f / r1, 2); - // Instead of the term above, we will divide the delta by the 1/r1, - // emphasizing that getting smaller runtime predictions wrong would - // contribute more to the error term than getting larger predictions wrong. - // We will experiment with adding powers and coefficients to r1. - //Expr delta = pow(1.0f / max(p1, 1e-10f) - 1.0f / r1, 2) / (r1*r1*r1*r1); - //Expr delta = pow(p1 - r1, 2) / (r1*r1*r1*r1); - // Expr delta = pow(p1 - r1, 2); - //Expr delta = log(cosh(p1 - r1)); - //Expr delta = pow( log(p1 + 1) - log(r1 + 1), 2) / (r1*r1*r1*r1); - //Expr delta = r1*r1 - 0.0001f*p1*p1; // after 20th batch things went downhill; try 128 samples per batch - //Expr delta = r1*r1 - 0.00001f*p1*p1 + log(cosh(p1 - r1)); // NOT BAD AT ALL! - //Expr delta = r1*r1 - 0.00001f*p1*p1 + 2.0f*(r1*log(r1) - r1*log(max(p1, 1e-10f))); // Last term is Kullback-Leibler divergence - //Expr delta = r1*r1 + 0.001f*p1*p1; // try coefficients in front of p1*p1; also try 1/(r1*r1) ... - // Expr delta = r1*r1 + p1*p1; // try coefficients in front of p1*p1; also try 1/(r1*r1) ... - //Expr delta = (r1*r1 + 0.000001f*p1*p1)*(r1*log(r1) - r1*log(max(p1, 1e-10f))); // try coefficients in front of p1*p1; also try 1/(r1*r1) ... - //Expr delta = r1*r1 - 0.000001f*p1*p1 + log(cosh(p1 - r1)); // - //Expr delta = exp(-0.22f*(r1-p1)) + 0.22f*(r1-p1) - 1.0f; - //Expr delta = exp(0.22f*(r1-p1)) - 0.22f*(r1-p1) - 1.0f; - //Expr delta = 0.6f*pow(r1-p1, 2)/(1.0f + exp(9.0f*(r1-p1))) + 0.4f; // DOES NOT WORK - //Expr delta = exp(-0.22f*(1.0f/r1 - 1.0f/max(p1, 1e-10f))) + 0.22f*(1.0f/r1 - 1.0f/max(p1, 1e-10f)) - 1.0f; - // Expr delta = 17.0f*(exp(-0.22f*(0.5f*r1-p1)) + 0.22f*(0.5f*r1-p1) - 1.0f); // Batch 20 is very interesting with 6 points below 1.6 at 16 sample run - //Expr delta = 17.0f*(exp(-0.22f*(0.3f*r1-p1)) + 0.22f*(0.3f*r1-p1) - 1.0f) + r1*r1; // Interesting! - Expr delta = 17.0f*(exp(-0.22f*(0.25f*r1-p1)) + 0.22f*(0.25f*r1-p1) - 1.0f) + r1*r1; // + Expr delta = pow(1.0f / max(p1, 1e-10f) - 1.0f / r1, 2); // Add the regulization with a small weight. err(n) = delta + 1e-5f * regularize; - //err(n) = delta + 0.0f * regularize; - //err(n) = delta; // Sum the errors over the batch. Expr loss = sum(err(r_batch)); @@ -507,9 +482,9 @@ class CostModel : public Generator> { true_runtime.set_estimates({{0, 80}}); // SCHEDULE - if (training && !auto_schedule) { + if (training && !using_autoscheduler()) { do_cost_model_schedule(get_pipeline()); - } else if (auto_schedule) { + } else if (using_autoscheduler()) { // Do nothing. } else { // We just write down a good schedule for diff --git a/src/autoschedulers/adams2019/featurization_to_sample.cpp b/src/autoschedulers/adams2019/featurization_to_sample.cpp index 60b674d1f010..fa94cb840cb9 100644 --- a/src/autoschedulers/adams2019/featurization_to_sample.cpp +++ b/src/autoschedulers/adams2019/featurization_to_sample.cpp @@ -3,62 +3,40 @@ #include #include - -const int ERROR = -1; -const int SUCCESS = 0; - -enum Args { - Executable, - InFeaturization, - Runtime, - PipelineId, - ScheduleId, - OutSample, - NumberOfArgs -}; - // A sample is a featurization + a runtime + some ids, all together in one file. // This utility concats the runtime and ids onto a featurization to produce a sample. - -// Sample command line: -// featurization_to_sample onnx_batch_0006_sample_0027.featurization 0.0022211699999999997 onnx 00060027 onnx_batch_0006_sample_0027.sample int main(int argc, char **argv) { - if (argc != NumberOfArgs) { + if (argc != 6) { std::cout << "Usage: featurization_to_sample in.featurization runtime pipeline_id schedule_id out.sample\n"; - return ERROR; + return -1; } - // Processing in.featurization parameter - std::ifstream src(argv[InFeaturization], std::ios::binary); + std::ifstream src(argv[1], std::ios::binary); if (!src) { - std::cerr << "Unable to open input file: " << argv[InFeaturization] << "\n"; - return ERROR; + std::cerr << "Unable to open input file: " << argv[1] << "\n"; + return -1; } - // Processing out.sample parameter - std::ofstream dst(argv[OutSample], std::ios::binary); + std::ofstream dst(argv[5], std::ios::binary); if (!dst) { - std::cerr << "Unable to open output file: " << argv[OutSample] << "\n"; - return ERROR; + std::cerr << "Unable to open output file: " << argv[5] << "\n"; + return -1; } dst << src.rdbuf(); // Input runtime value is presumed to be in seconds, // but sample file stores times in milliseconds. - // processing run time parameter - float runtime = atof(argv[Runtime]) * 1000.f; - // processing pipeline_id parameter - int32_t pipeline_id = atoi(argv[PipelineId]); - // processing schedule_id parameter - int32_t schedule_id = atoi(argv[ScheduleId]); + float r = atof(argv[2]) * 1000.f; + int32_t pid = atoi(argv[3]); + int32_t sid = atoi(argv[4]); - dst.write((const char *)&runtime, sizeof(float)); - dst.write((const char *)&pipeline_id, sizeof(int32_t)); - dst.write((const char *)&schedule_id, sizeof(int32_t)); + dst.write((const char *)&r, 4); + dst.write((const char *)&pid, 4); + dst.write((const char *)&sid, 4); src.close(); dst.close(); - return SUCCESS; + return 0; } diff --git a/src/autoschedulers/adams2019/included_schedule_file_generator.cpp b/src/autoschedulers/adams2019/included_schedule_file_generator.cpp index 21ee6ec0918c..cdd2bc7f6bf3 100644 --- a/src/autoschedulers/adams2019/included_schedule_file_generator.cpp +++ b/src/autoschedulers/adams2019/included_schedule_file_generator.cpp @@ -37,7 +37,7 @@ struct IncludedScheduleFile : public Halide::Generator { relu.set_estimates({{0, CO}, {0, W}, {0, H}, {0, N}}); // Schedule - if (auto_schedule) { + if (using_autoscheduler()) { // nothing } else { #if defined(GENERATING_SCHEDULE) diff --git a/src/autoschedulers/adams2019/retrain_cost_model.cpp b/src/autoschedulers/adams2019/retrain_cost_model.cpp index d62bab16f1ae..0f8e48a16531 100644 --- a/src/autoschedulers/adams2019/retrain_cost_model.cpp +++ b/src/autoschedulers/adams2019/retrain_cost_model.cpp @@ -33,7 +33,6 @@ struct Flags { bool randomize_weights = false; string best_benchmark_path; string best_schedule_path; - string best_python_schedule_path; Flags(int argc, char **argv) { cmdline::parser a; @@ -49,7 +48,6 @@ struct Flags { a.add("num_cores"); a.add("best_benchmark"); a.add("best_schedule"); - a.add("best_python_schedule"); a.parse_check(argc, argv); // exits if parsing fails @@ -60,7 +58,6 @@ struct Flags { randomize_weights = a.exist("randomize_weights") && a.get("randomize_weights"); best_benchmark_path = a.get("best_benchmark"); best_schedule_path = a.get("best_schedule"); - best_python_schedule_path = a.get("best_python_schedule"); if (epochs <= 0) { std::cerr << "--epochs must be specified and > 0.\n"; @@ -353,23 +350,17 @@ map load_samples(const Flags &flags) { f.close(); assert(!f.fail()); } - - auto copy_best_schedule = [&best_path](const std::string& schedule_path, const std::string& extension) { - if (!schedule_path.empty()) { - // best_path points to a .sample file; look for a .schedule.h file in the same dir - size_t dot = best_path.rfind('.'); - assert(dot != string::npos && best_path.substr(dot) == ".sample"); - string schedule_file = best_path.substr(0, dot) + extension; - std::ifstream src(schedule_file); - std::ofstream dst(schedule_path); - dst << src.rdbuf(); - assert(!src.fail()); - assert(!dst.fail()); - } - }; - - copy_best_schedule(flags.best_schedule_path, ".schedule.h"); - copy_best_schedule(flags.best_python_schedule_path, "_schedule.py"); + if (!flags.best_schedule_path.empty()) { + // best_path points to a .sample file; look for a .schedule.h file in the same dir + size_t dot = best_path.rfind('.'); + assert(dot != string::npos && best_path.substr(dot) == ".sample"); + string schedule_file = best_path.substr(0, dot) + ".schedule.h"; + std::ifstream src(schedule_file); + std::ofstream dst(flags.best_schedule_path); + dst << src.rdbuf(); + assert(!src.fail()); + assert(!dst.fail()); + } return result; } diff --git a/src/autoschedulers/adams2019/test-temp.cpp b/src/autoschedulers/adams2019/test-temp.cpp deleted file mode 100644 index 24f070618aa4..000000000000 --- a/src/autoschedulers/adams2019/test-temp.cpp +++ /dev/null @@ -1,484 +0,0 @@ -#include "Halide.h" - -using namespace Halide; - -int main(int argc, char **argv) { - if (argc != 2) { - fprintf(stderr, "Usage: %s \n", argv[0]); - return 1; - } - - load_plugin(argv[1]); - - MachineParams params(32, 16000000, 40); - // Use a fixed target for the analysis to get consistent results from this test. - Target target("x86-64-linux-sse41-avx-avx2"); - - Var x("x"), y("y"); - - if (true) { - // In a point-wise pipeline, everything should be fully fused. - Func f("f"), g("g"), h("h"); - f(x, y) = (x + y) * (x + y); - g(x, y) = f(x, y) * 2 + 1; - h(x, y) = g(x, y) * 2 + 1; - - h.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); - - Pipeline(h).auto_schedule(target, params); - } - - if (true) { - // In a pipeline with huge expensive stencils and low memory costs, nothing should be fused - Func f("f"), g("g"), h("h"); - f(x, y) = (x + y) * (x + 2 * y) * (x + 3 * y) * (x + 4 * y) * (x + 5 * y); - Expr e = 0; - for (int i = 0; i < 100; i++) { - e += f(x + i * 10, y + i * 10); - } - g(x, y) = e; - e = 0; - for (int i = 0; i < 100; i++) { - e += g(x + i * 10, y + i * 10); - } - h(x, y) = e; - - h.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); - - Pipeline(h).auto_schedule(target, params); - } - - if (true) { - // In a pipeline with moderate isotropic stencils, there should be some square tiling - Func f("f"), h("h"); - f(x, y) = (x + y) * (x + 2 * y) * (x + 3 * y); - h(x, y) = (f(x - 9, y - 9) + f(x, y - 9) + f(x + 9, y - 9) + - f(x - 9, y) + f(x, y) + f(x + 9, y) + - f(x - 9, y + 9) + f(x, y + 9) + f(x + 9, y - 9)); - - h.set_estimate(x, 0, 2048).set_estimate(y, 0, 2048); - - Pipeline(h).auto_schedule(target, params); - } - - // Smaller footprint stencil -> smaller tiles - if (true) { - Func f("f"), g("g"), h("h"); - f(x, y) = (x + y) * (x + 2 * y) * (x + 3 * y); - h(x, y) = (f(x - 1, y - 1) + f(x, y - 1) + f(x + 1, y - 1) + - f(x - 1, y) + f(x, y) + f(x + 1, y) + - f(x - 1, y + 1) + f(x, y + 1) + f(x + 1, y - 1)); - - h.set_estimate(x, 0, 2048).set_estimate(y, 0, 2048); - - Pipeline(h).auto_schedule(target, params); - } - - // A stencil chain - if (true) { - const int N = 8; - Func f[N]; - f[0](x, y) = (x + y) * (x + 2 * y) * (x + 3 * y); - for (int i = 1; i < N; i++) { - Expr e = 0; - for (int dy = -2; dy <= 2; dy++) { - for (int dx = -2; dx <= 2; dx++) { - e += f[i - 1](x + dx, y + dy); - } - } - f[i](x, y) = e; - } - f[N - 1].set_estimate(x, 0, 2048).set_estimate(y, 0, 2048); - - Pipeline(f[N - 1]).auto_schedule(target, params); - } - - // An outer product - if (true) { - Buffer a(2048), b(2048); - Func f; - f(x, y) = a(x) * b(y); - - f.set_estimate(x, 0, 2048).set_estimate(y, 0, 2048); - - Pipeline(f).auto_schedule(target, params); - } - - // A separable downsample that models the start of local_laplacian - if (true) { - Buffer in(2048, 2048); - Var k; - Func orig("orig"), expensive("expensive"), downy("downy"), downx("downx"); - Expr e = 0; - for (int i = 0; i < 100; i++) { - e += 1; - e *= e; - } - orig(x, y) = e; - expensive(x, y, k) = orig(x, y) * orig(x, y) + (x + orig(x, y)) * (1 + orig(x, y)) + sqrt(k + orig(x, y)); - downy(x, y, k) = expensive(x, 2 * y - 1, k) + expensive(x, 2 * y, k) + expensive(x, 2 * y + 1, k) + expensive(x, 2 * y + 2, k); - downx(x, y, k) = downy(2 * x - 1, y, k) + downy(2 * x, y, k) + downy(2 * x + 1, y, k) + downy(2 * x + 2, y, k); - downx.set_estimate(x, 1, 1022).set_estimate(y, 1, 1022).set_estimate(k, 0, 256); - - Pipeline(downx).auto_schedule(target, params); - } - - // A Func with multiple stages, some of which include additional loops - if (true) { - Buffer a(1024, 1024); - Func f("multiple_stages"), g("g"), h("h"); - Var x, y; - h(x, y) = pow(x, y); - f(x, y) = a(x, y) * 2; - f(x, y) += 17; - RDom r(0, 10); - f(x, y) += r * h(x, y); - f(x, y) *= 2; - f(0, y) = 23.0f; - g(x, y) = f(x - 1, y - 1) + f(x + 1, y + 1); - - g.set_estimate(x, 1, 1022).set_estimate(y, 1, 1022); - - Pipeline(g).auto_schedule(target, params); - } - - if (true) { - // A scan with pointwise stages before and after - Buffer a(1024, 1024); - Func before[5]; - Func after[5]; - Func s("scan"); - Var x, y; - before[0](x, y) = x + y; - for (int i = 1; i < 5; i++) { - before[i](x, y) = before[i - 1](x, y) + 1; - } - RDom r(1, 1023); - s(x, y) = before[4](x, y); - s(r, y) += s(r - 1, y); - after[0](x, y) = s(y, x) + s(y, x + 100); - for (int i = 1; i < 5; i++) { - after[i](x, y) = after[i - 1](x, y) + 1; - } - - after[4].set_estimate(x, 0, 1024).set_estimate(y, 0, 1024); - - Pipeline(after[4]).auto_schedule(target, params); - } - - if (true) { - Func f_u8("f_u8"); - Func f_u64_1("f_u64_1"); - Func f_u64_2("f_u64_2"); - Buffer a(1024 * 1024 + 2); - - Var x; - f_u8(x) = (min(a(x) + 1, 17) * a(x + 1) + a(x + 2)) * a(x) * a(x) * a(x + 1) * a(x + 1); - f_u64_1(x) = cast(f_u8(x)) + 1; - f_u64_2(x) = f_u64_1(x) * 3; - - // Ignoring the types, it would make sense to inline - // everything into f_64_2 but this would vectorize fairly - // narrowly, which is a waste of work for the first Func. - - f_u64_2.set_estimate(x, 0, 1024 * 1024); - - Pipeline(f_u64_2).auto_schedule(target, params); - } - - if (true) { - Buffer im_a(1024, 1024, "a"), im_b(1024, 1024, "b"); - im_a.fill(0.0f); - im_b.fill(0.0f); - - Func c("c"), a("a"), b("b"); - Var i, j; - a(j, i) = im_a(j, i); // TODO: Add wrappers to the search space - b(j, i) = im_b(j, i); - RDom k(0, 1024); - c(j, i) += a(k, i) * b(j, k); - Func out("out"); - out(j, i) = c(j, i); - - out.set_estimate(j, 0, 1024).set_estimate(i, 0, 1024); - - Pipeline(out).auto_schedule(target, params); - } - - if (true) { - // A scan in x followed by a downsample in y, with pointwise stuff in between - const int N = 3; - Buffer a(1024, 1024); - Func p1[N], p2[N], p3[N]; - Func s("scan"); - Var x, y; - p1[0](x, y) = x + y; - for (int i = 1; i < N; i++) { - p1[i](x, y) = p1[i - 1](x, y) + 1; - } - RDom r(1, 1023); - s(x, y) = p1[N - 1](x, y); - s(r, y) += s(r - 1, y); - p2[0](x, y) = s(x, y); - for (int i = 1; i < N; i++) { - p2[i](x, y) = p2[i - 1](x, y) + 1; - } - Func down("downsample"); - down(x, y) = p2[N - 1](x, 2 * y); - p3[0](x, y) = down(x, y); - for (int i = 1; i < N; i++) { - p3[i](x, y) = p3[i - 1](x, y) + 1; - } - - p3[N - 1].set_estimate(x, 0, 1024).set_estimate(y, 0, 1024); - - Pipeline(p3[N - 1]).auto_schedule(target, params); - } - - if (true) { - // A gather that only uses a small portion of a potentially - // large LUT. The number of points computed should be less - // than points computed minimum, and the LUT should be - // inlined, even if it's really expensive. - Func lut("lut"); - Var x; - lut(x) = (x + 1) * (x + 2) * (x + 3) * (x + 4) * (x + 5) * (x + 6); - - Func idx("idx"); - idx(x) = x * (10000 - x); - - Func out("out"); - out(x) = lut(clamp(idx(x), 0, 100000)); - - out.set_estimate(x, 0, 10); - - Pipeline(out).auto_schedule(target, params); - } - - if (true) { - // A schedule where it's insane to not compute inside an rvar - Func f("f"), g("g"); - f(x, y) = x; - f(x, y) += 1; - - RDom r(0, 100); - g(x, y) = 0; - g(x, y) += f(x, 1000 * (y + r)); - - g.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); - - Pipeline(g).auto_schedule(target, params); - } - - if (true) { - // A pipeline where the vectorized dimension should alternate index - Func f("f"), g("g"), h("h"); - f(x, y) = x * y; - - RDom r(-50, 100, -50, 100); - g(x, y) += f(y + r.y, x + r.x); - - h(x, y) += g(y + r.y, x + r.y); - - h.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); - - Pipeline(h).auto_schedule(target, params); - } - - if (true) { - // A no-win scenario in which a Func is going to be read from - // lots of times using a vector gather no matter how it is - // scheduled. - Func in("in"), a("a"), b("b"); - - in(x, y) = sqrt(sqrt(sqrt(sqrt(x * y)))); - - RDom r(-50, 100, -50, 100); - a(x, y) += in(x + r.x, y + r.y); - b(x, y) += in(y + r.y, x + r.x); - - a.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); - b.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); - - Pipeline({a, b}).auto_schedule(target, params); - } - - if (true) { - // Boring memcpy - ImageParam im(Float(32), 2); - Func f("f"), g("g"); - f(x, y) = im(x, y); - g(x, y) = f(x, y); - - g.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); - Pipeline(g).auto_schedule(target, params); - } - - if (true) { - // A load from a tiny input image - ImageParam im(Float(32), 2); - Func f("f"); - f(x, y) = im(x, y) * 7; - - f.set_estimate(x, 0, 3).set_estimate(y, 0, 5); - Pipeline(f).auto_schedule(target, params); - } - - if (true) { - // Lots of dimensions - ImageParam im(Float(32), 7); - Func f("f"); - Var z, w, t, u, v; - f(x, y, z, w, t, u, v) = im(x, y, z, w, t, u, v) * 7; - - f.set_estimate(x, 0, 8) - .set_estimate(y, 0, 9) - .set_estimate(z, 0, 10) - .set_estimate(w, 0, 5) - .set_estimate(t, 0, 3) - .set_estimate(u, 0, 2) - .set_estimate(v, 0, 6); - Pipeline(f).auto_schedule(target, params); - } - - if (true) { - // Long transpose chain. - ImageParam im(Float(32), 2); - Func f("f"), g("g"), h("h"); - - f(x, y) = im(clamp(y * x, 0, 999), x); - g(x, y) = f(clamp(y * x, 0, 999), x); - h(x, y) = g(clamp(y * x, 0, 999), x); - - // Force everything to be compute root by accessing them in two separate outputs - Func out1("out1"), out2("out2"); - out1(x, y) = f(x, y) + g(x, y) + h(x, y); - out2(x, y) = f(x, y) + g(x, y) + h(x, y); - - out1.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); - out2.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); - Pipeline({out1, out2}).auto_schedule(target, params); - } - - if (true) { - ImageParam im(Float(32), 2); - // An inlinable Func used at the start and at the end of a long stencil chain. - const int N = 8; - Func f[N]; - f[0] = Func("inline_me"); - f[0](x, y) = im(x, y); // inline me! - for (int i = 1; i < N; i++) { - Expr e = 0; - for (int dy = -1; dy <= 1; dy++) { - for (int dx = -1; dx <= 1; dx++) { - e += f[i - 1](x + dx, y + dy); - } - } - f[i](x, y) = e; - } - - Func g("output"); - // Access it in a way that makes it insane not to inline. - g(x, y) = f[N - 1](x, y) + f[0](clamp(cast(sin(x) * 10000), 0, 100000), clamp(cast(sin(x * y) * 10000), 0, 100000)); - g.set_estimate(x, 0, 2048).set_estimate(y, 0, 2048); - - Pipeline(g).auto_schedule(target, params); - } - - if (true) { - Func f("f"), g("g"), h("h"); - - f(x, y) = x + y; - ; - g() = f(3, 2); - RDom r(0, 100); - g() += r; - h(x, y) = g() + x + y; - - h.set_estimate(x, 0, 1024).set_estimate(y, 0, 2048); - Pipeline(h).auto_schedule(target, params); - } - - if (true) { - // Vectorizing a pure var in an update using RoundUp - - Func f("f"), g("g"); - - f(x, y) = x + y; - RDom r(0, 10); - f(x, y) += f(x, y) * r; - - g(x, y) = f(x, y); - - g.set_estimate(x, 0, 10).set_estimate(y, 0, 2048); - Pipeline(g).auto_schedule(target, params); - } - - if (true) { - ImageParam im(Float(32), 2); - - // A convolution pyramid - Func up[8], down[8]; - int sz = 2048; - Func prev("input"); - prev(x, y) = im(x, y); - - const int N = 4; - - for (int i = 0; i < N; i++) { - up[i] = Func("up" + std::to_string(i)); - down[i] = Func("down" + std::to_string(i)); - down[i](x, y) = prev(2 * x - 10, 2 * y - 10) + prev(2 * x + 10, 2 * y + 10); - prev = BoundaryConditions::repeat_edge(down[i], {{0, sz}, {0, sz}}); - // prev = down[i]; - sz /= 2; - } - - for (int i = N - 1; i >= 0; i--) { - up[i](x, y) = prev(x / 2 + 10, y / 2 + 10) + prev(x / 2 - 10, y / 2 - 10) + down[i](x, y); - prev = up[i]; - } - - Func out; - out(x, y) = up[0](x, y); - - out.set_estimate(x, 0, 2048).set_estimate(y, 0, 2048); - Pipeline(out).auto_schedule(target, params); - } - - if (true) { - ImageParam im(Float(32), 2); - - Func f("f"); - f(x, y) = im(x, y); - - Func scan("scan"); - scan(x, y) = f(x, y); - RDom r(1, 1999); - scan(x, r) += scan(x, r - 1); - scan(x, 1999 - r) += scan(x, 2000 - r); - Func casted("casted"); - casted(x, y) = scan(x, y); - - casted.set_estimate(x, 0, 2000).set_estimate(y, 0, 2000); - Pipeline(casted).auto_schedule(target, params); - } - - if (true) { - ImageParam im(Int(32), 2); - - Func f("f"), hist("hist"), output("output"); - Var i("i"); - f(x, y) = clamp(im(x, y), 0, 255); - RDom r(0, 2000, 0, 2000); - hist(i) = cast(0); - hist(f(r.x, r.y)) += cast(1); - output(i) = hist(i); - - f.set_estimate(x, 0, 2000).set_estimate(y, 0, 2000); - output.set_estimate(i, 0, 256); - Pipeline(output).auto_schedule(target, params); - } - - return 0; -} diff --git a/src/autoschedulers/adams2019/test.cpp b/src/autoschedulers/adams2019/test.cpp index 21e0f0ec20bb..8ee560260fa1 100644 --- a/src/autoschedulers/adams2019/test.cpp +++ b/src/autoschedulers/adams2019/test.cpp @@ -14,7 +14,13 @@ void set_env_variable(const std::string &name, const std::string &value, int ove #endif } -bool test_caching(Pipeline &p1, Pipeline &p2, const Target &target, const MachineParams ¶ms) { +std::string weights_path; + +bool test_caching(Pipeline &p1, Pipeline &p2, const Target &target) { + constexpr int parallelism = 32; +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + MachineParams params(parallelism, 16000000, 40); + static const std::string seed_value = Internal::get_env_variable("HL_SEED"); if (seed_value.empty()) { // If HL_SEED is not set, then set seed for both autoscheduling executions. @@ -25,11 +31,13 @@ bool test_caching(Pipeline &p1, Pipeline &p2, const Target &target, const Machin // Turn off caching. set_env_variable("HL_DISABLE_MEMOIZED_FEATURES", "1", /* overwrite */ 1); set_env_variable("HL_DISABLE_MEMOIZED_BLOCKS", "1", /* overwrite */ 1); + auto results_without_caching = p1.auto_schedule(target, params); // Turn on caching. set_env_variable("HL_DISABLE_MEMOIZED_FEATURES", "0", /* overwrite */ 1); set_env_variable("HL_DISABLE_MEMOIZED_BLOCKS", "0", /* overwrite */ 1); + auto results_with_caching = p2.auto_schedule(target, params); // Reset environment variables to what they were before (memoization variables are reset in main). @@ -37,6 +45,29 @@ bool test_caching(Pipeline &p1, Pipeline &p2, const Target &target, const Machin // Re-empty seed. set_env_variable("HL_SEED", "", /* overwrite */ 1); } +#else + int seed = (int)time(nullptr); + AutoschedulerParams params( + "Adams2019", + { + {"parallelism", std::to_string(parallelism)}, + {"random_dropout_seed", std::to_string(seed)}, + {"weights_path", weights_path}, + // Turn off caching. + {"disable_memoized_features", "1"}, + {"disable_memoized_blocks", "1"}, + }); + + // Turn off caching. + params.extra["disable_memoized_features"] = "1"; + params.extra["disable_memoized_blocks"] = "1"; + auto results_without_caching = p1.apply_autoscheduler(target, params); + + // Turn on caching. + params.extra["disable_memoized_features"] = "0"; + params.extra["disable_memoized_blocks"] = "0"; + auto results_with_caching = p2.apply_autoscheduler(target, params); +#endif // Compare calculated features. if (results_without_caching.featurization.size() != results_with_caching.featurization.size()) { @@ -54,18 +85,20 @@ bool test_caching(Pipeline &p1, Pipeline &p2, const Target &target, const Machin } int main(int argc, char **argv) { - if (argc != 2) { - fprintf(stderr, "Usage: %s \n", argv[0]); + if (argc != 3 || !strlen(argv[1]) || !strlen(argv[2])) { + fprintf(stderr, "Usage: %s \n", argv[0]); return 1; } load_plugin(argv[1]); + weights_path = argv[2]; +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API // Tests will mess with these environment variables, we save them in order to reset them later. const std::string cache_features = Internal::get_env_variable("HL_DISABLE_MEMOIZED_FEATURES"); const std::string cache_blocks = Internal::get_env_variable("HL_DISABLE_MEMOIZED_BLOCKS"); +#endif - MachineParams params(32, 16000000, 40); // Use a fixed target for the analysis to get consistent results from this test. Target target("x86-64-linux-sse41-avx-avx2"); @@ -90,7 +123,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on point-wise pipeline" << std::endl; return 1; } @@ -123,7 +156,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on huge expensive stencils and low memory costs" << std::endl; return 1; } @@ -149,7 +182,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on moderate isotropic stencils" << std::endl; return 1; } @@ -175,7 +208,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on smaller footprint stencil" << std::endl; return 1; } @@ -207,7 +240,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on stencil chain" << std::endl; return 1; } @@ -231,7 +264,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on an outer product" << std::endl; return 1; } @@ -263,7 +296,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on a separable downsample" << std::endl; return 1; } @@ -295,7 +328,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on Func with multiple stages + loops" << std::endl; return 1; } @@ -332,7 +365,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on scan with pointwise stages before and after" << std::endl; return 1; } @@ -365,7 +398,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on bad vectorization" << std::endl; return 1; } @@ -397,7 +430,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on matrix multiply + wrapper" << std::endl; return 1; } @@ -440,7 +473,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(pipeline1, pipeline2, target, params)) { + if (!test_caching(pipeline1, pipeline2, target)) { std::cerr << "Caching check failed on scan + downsample" << std::endl; return 1; } @@ -473,7 +506,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on gather with LUT" << std::endl; return 1; } @@ -501,7 +534,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on 'compute inside an rvar'" << std::endl; return 1; } @@ -529,7 +562,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on alternating vectorized dimensions" << std::endl; return 1; } @@ -560,7 +593,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on no-win scenario" << std::endl; return 1; } @@ -585,7 +618,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on boring memcpy" << std::endl; return 1; } @@ -609,7 +642,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on load from a tiny input image" << std::endl; return 1; } @@ -640,7 +673,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on many-dimension func" << std::endl; return 1; } @@ -673,7 +706,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on long transpose chain" << std::endl; return 1; } @@ -711,7 +744,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on inlines + stencil chain" << std::endl; return 1; } @@ -738,7 +771,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on alternating vectorized dimensions" << std::endl; return 1; } @@ -766,7 +799,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on vectorizable with pure var using RoundUp" << std::endl; return 1; } @@ -812,7 +845,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on convolution pyramid" << std::endl; return 1; } @@ -844,7 +877,7 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on casted scan" << std::endl; return 1; } @@ -874,15 +907,41 @@ int main(int argc, char **argv) { } } - if (!test_caching(p1, p2, target, params)) { + if (!test_caching(p1, p2, target)) { std::cerr << "Caching check failed on histogram" << std::endl; return 1; } } + // A trivial pipeline that just loads from a LUT + if (true) { + Pipeline p1; + Pipeline p2; + for (int test_condition = 0; test_condition < 2; test_condition++) { + Buffer lut(256); + Func f; + f(x) = lut(x); + + f.set_estimate(x, 0, 256); + + if (test_condition) { + p2 = Pipeline(f); + } else { + p1 = Pipeline(f); + } + } + + if (!test_caching(p1, p2, target)) { + std::cerr << "Caching check failed on stencil chain" << std::endl; + return 1; + } + } + +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API // Reset environment variables. set_env_variable("HL_DISABLE_MEMOIZED_FEATURES", cache_features, /* overwrite */ 1); set_env_variable("HL_DISABLE_MEMOIZED_BLOCKS", cache_blocks, /* overwrite */ 1); +#endif std::cout << "adams2019 testing passed\n"; return 0; diff --git a/src/autoschedulers/adams2019/test_function_dag.cpp b/src/autoschedulers/adams2019/test_function_dag.cpp index 933c7f9d5027..0b4604b9500d 100644 --- a/src/autoschedulers/adams2019/test_function_dag.cpp +++ b/src/autoschedulers/adams2019/test_function_dag.cpp @@ -1,3 +1,4 @@ +#include "Featurization.h" #include "FunctionDAG.h" #include "Halide.h" #include @@ -31,7 +32,7 @@ extern "C" int mul_by_two( return 0; } -void test_coeff_wise(const MachineParams ¶ms, const Target &target) { +void test_coeff_wise(const Target &target) { Var x("x"), y("y"); std::ostringstream with_extern; @@ -55,7 +56,7 @@ void test_coeff_wise(const MachineParams ¶ms, const Target &target) { h.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); std::vector v; v.push_back(h.function()); - Halide::Internal::Autoscheduler::FunctionDAG d(v, params, target); + Halide::Internal::Autoscheduler::FunctionDAG d(v, target); d.dump(with_extern); } @@ -70,13 +71,13 @@ void test_coeff_wise(const MachineParams ¶ms, const Target &target) { h.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); std::vector v; v.push_back(h.function()); - Halide::Internal::Autoscheduler::FunctionDAG d(v, params, target); + Halide::Internal::Autoscheduler::FunctionDAG d(v, target); d.dump(without_extern); } // Disabled for now: there is still work to do to populate the jacobian - //assert(with_extern.str() == without_extern.str()); + // assert(with_extern.str() == without_extern.str()); } extern "C" int matmul( @@ -113,7 +114,7 @@ extern "C" int matmul( return 0; } -void test_matmul(const MachineParams ¶ms, const Target &target) { +void test_matmul(const Target &target) { Var x("x"), y("y"), k("k"); RDom r(0, 200); Halide::Buffer input1(200, 200); @@ -140,7 +141,7 @@ void test_matmul(const MachineParams ¶ms, const Target &target) { h.set_estimate(x, 0, 200).set_estimate(y, 0, 200); std::vector v; v.push_back(h.function()); - Halide::Internal::Autoscheduler::FunctionDAG d(v, params, target); + Halide::Internal::Autoscheduler::FunctionDAG d(v, target); d.dump(with_extern); } @@ -153,7 +154,7 @@ void test_matmul(const MachineParams ¶ms, const Target &target) { h.set_estimate(x, 0, 200).set_estimate(y, 0, 200); std::vector v; v.push_back(h.function()); - Halide::Internal::Autoscheduler::FunctionDAG d(v, params, target); + Halide::Internal::Autoscheduler::FunctionDAG d(v, target); d.dump(without_extern); } @@ -164,11 +165,10 @@ void test_matmul(const MachineParams ¶ms, const Target &target) { int main(int argc, char **argv) { // Use a fixed target for the analysis to get consistent results from this test. - MachineParams params(32, 16000000, 40); Target target("x86-64-linux-sse41-avx-avx2"); - test_coeff_wise(params, target); - test_matmul(params, target); + test_coeff_wise(target); + test_matmul(target); return 0; } diff --git a/src/autoschedulers/adams2019/ASLog.cpp b/src/autoschedulers/common/ASLog.cpp similarity index 100% rename from src/autoschedulers/adams2019/ASLog.cpp rename to src/autoschedulers/common/ASLog.cpp diff --git a/src/autoschedulers/adams2019/ASLog.h b/src/autoschedulers/common/ASLog.h similarity index 82% rename from src/autoschedulers/adams2019/ASLog.h rename to src/autoschedulers/common/ASLog.h index f1a7da8fd5eb..5e0088b34d13 100644 --- a/src/autoschedulers/adams2019/ASLog.h +++ b/src/autoschedulers/common/ASLog.h @@ -5,6 +5,7 @@ // libHalide, so (despite the namespace) we are better off not // including Halide.h, lest we reference something we won't have available +#include #include #include #include @@ -28,6 +29,12 @@ class aslog { return *this; } + std::ostream &get_ostream() { + // It is an error to call this for an aslog() instance that cannot log. + assert(logging); + return std::cerr; + } + static int aslog_level(); }; diff --git a/src/autoschedulers/common/CMakeLists.txt b/src/autoschedulers/common/CMakeLists.txt index 90693889928b..c8afd5d2cdda 100644 --- a/src/autoschedulers/common/CMakeLists.txt +++ b/src/autoschedulers/common/CMakeLists.txt @@ -2,3 +2,11 @@ add_library(Halide_Plugin INTERFACE) add_library(Halide::Plugin ALIAS Halide_Plugin) target_include_directories(Halide_Plugin INTERFACE $) target_link_libraries(Halide_Plugin INTERFACE Halide::Halide) + +add_library(ASLog STATIC ASLog.cpp) +target_include_directories(ASLog PUBLIC $) + +# Sigh, header-only libraries shouldn't be special +add_library(ParamParser INTERFACE) +target_include_directories(ParamParser INTERFACE + $) diff --git a/src/autoschedulers/common/HalidePlugin.h b/src/autoschedulers/common/HalidePlugin.h index 7e6636bb09c0..c5ffb1e16ec9 100644 --- a/src/autoschedulers/common/HalidePlugin.h +++ b/src/autoschedulers/common/HalidePlugin.h @@ -11,4 +11,4 @@ } \ } register_##NAME; -#endif //HALIDE_HALIDEPLUGIN_H +#endif // HALIDE_HALIDEPLUGIN_H diff --git a/src/autoschedulers/common/ParamParser.h b/src/autoschedulers/common/ParamParser.h new file mode 100644 index 000000000000..25c943f6ce11 --- /dev/null +++ b/src/autoschedulers/common/ParamParser.h @@ -0,0 +1,74 @@ +#ifndef PARSE_H +#define PARSE_H + +#include "Errors.h" +#include +#include + +namespace Halide { +namespace Internal { +namespace Autoscheduler { + +class ParamParser { + std::map extra; + + // If the string can be parsed as a valid "T", set *value to it. + // If not, assert-fail. + template + static void parse_or_die(const std::string &str, T *value) { + std::istringstream iss(str); + T t; + // All one-byte ints int8 and uint8 should be parsed as integers, not chars -- + // including 'char' itself. (Note that sizeof(bool) is often-but-not-always-1, + // so be sure to exclude that case.) + if constexpr (sizeof(T) == sizeof(char) && !std::is_same::value) { + int i; + iss >> i; + t = (T)i; + } else { + iss >> t; + } + user_assert(!iss.fail() && iss.get() == EOF) << "Unable to parse: " << str; + *value = t; + } + +public: + explicit ParamParser(const std::map &m) + : extra(m) { + } + + // If the given key is present in m, parse the result into *value and return true. + // (If the string cannot be parsed as a valid "T", assert-fail.) + // If the given key is not present, leave *value untouched and return false. + template + bool parse(const std::string &key, T *value) { + auto it = extra.find(key); + if (it == extra.end()) { + return false; + } + parse_or_die(it->second, value); + extra.erase(it); + return true; + } + + void finish() { + if (!extra.empty()) { + std::ostringstream oss; + oss << "Autoscheduler Params contain unknown keys:\n"; + for (const auto &it : extra) { + oss << " " << it.first << "\n"; + } + user_error << oss.str(); + } + } + + ~ParamParser() { + finish(); + } +}; + +} // namespace Autoscheduler +} // namespace Internal +} // namespace Halide + +#endif diff --git a/src/autoschedulers/li2018/CMakeLists.txt b/src/autoschedulers/li2018/CMakeLists.txt index 809689012b35..f1d5b3c6f90a 100644 --- a/src/autoschedulers/li2018/CMakeLists.txt +++ b/src/autoschedulers/li2018/CMakeLists.txt @@ -1,4 +1,5 @@ add_autoscheduler(NAME Li2018 SOURCES GradientAutoscheduler.cpp) +target_link_libraries(Halide_Li2018 PRIVATE ParamParser) # ========================================================== # TODO(#4053): move these to a separate folder since they're tests. @@ -19,7 +20,7 @@ target_link_libraries(demo_gradient_autoscheduler PRIVATE demo_gradient Halide:: add_test(NAME demo_gradient_autoscheduler COMMAND demo_gradient_autoscheduler --benchmarks=all --benchmark_min_time=1 --estimate_all) -set_tests_properties(demo_gradient_autoscheduler PROPERTIES LABELS Li2018) +set_tests_properties(demo_gradient_autoscheduler PROPERTIES LABELS "Li2018;multithreaded;auto_schedule") ## @@ -30,7 +31,7 @@ if (BUILD_SHARED_LIBS) add_test(NAME gradient_autoscheduler_test_cpp COMMAND gradient_autoscheduler_test_cpp $) - set_tests_properties(gradient_autoscheduler_test_cpp PROPERTIES LABELS Li2018) + set_tests_properties(gradient_autoscheduler_test_cpp PROPERTIES LABELS "Li2018;auto_schedule") endif () ## @@ -44,18 +45,21 @@ if (WITH_PYTHON_BINDINGS) add_test(NAME gradient_autoscheduler_test_py COMMAND Python3::Interpreter "${CMAKE_CURRENT_SOURCE_DIR}/test.py") - set(PYTHONPATH "$>") + set( + PYTHONPATH + "$/.." + ) + list(TRANSFORM PYTHONPATH PREPEND "PYTHONPATH=path_list_prepend:") - if (WIN32) - set(SEP "\\$") - else () - set(SEP ":") - endif () + set( + PATH + "$" + "$" + ) + list(TRANSFORM PATH PREPEND "PATH=path_list_prepend:") - set(_PATH "$>;$>;$ENV{PATH}") - string(REPLACE ";" "${SEP}" _PATH "${_PATH}") set_tests_properties(gradient_autoscheduler_test_py PROPERTIES - LABELS Li2018 - ENVIRONMENT "PYTHONPATH=${PYTHONPATH};PATH=${_PATH}") + LABELS "Li2018;auto_schedule" + ENVIRONMENT_MODIFICATION "${PYTHONPATH};${PATH}") endif () endif () diff --git a/src/autoschedulers/li2018/GradientAutoscheduler.cpp b/src/autoschedulers/li2018/GradientAutoscheduler.cpp index b3c2cd93e233..204210e987c7 100644 --- a/src/autoschedulers/li2018/GradientAutoscheduler.cpp +++ b/src/autoschedulers/li2018/GradientAutoscheduler.cpp @@ -1,6 +1,7 @@ #include "Errors.h" #include "Halide.h" #include "HalidePlugin.h" +#include "ParamParser.h" namespace Halide { namespace Internal { @@ -8,6 +9,11 @@ namespace Autoscheduler { namespace { +struct GradientAutoschedulerParams { + /** Maximum level of parallelism available. */ + int parallelism = 16; +}; + std::map inference_bounds(const std::vector &functions, const std::vector &output_bounds) { std::vector funcs; @@ -86,7 +92,7 @@ int natural_vector_size(const Target &target, const Type &t) { template void parallelize_vars_and_rvars_gpu( - const MachineParams ¶ms, + const GradientAutoschedulerParams ¶ms, FuncOrStage func_or_stage, bool is_pure_def, const std::vector &vars, @@ -324,7 +330,7 @@ void parallelize_vars_and_rvars_gpu( template void parallelize_vars_and_rvars_cpu( - const MachineParams ¶ms, + const GradientAutoschedulerParams ¶ms, FuncOrStage func_or_stage, int natural_vector_size, bool is_pure_def, @@ -528,7 +534,7 @@ void parallelize_vars_and_rvars_cpu( template void parallelize_vars_and_rvars( - const MachineParams ¶ms, + const GradientAutoschedulerParams ¶ms, FuncOrStage func_or_stage, int natural_vector_size, bool is_pure_def, @@ -565,7 +571,7 @@ void parallelize_vars_and_rvars( } } -void apply_schedule(const MachineParams ¶ms, +void apply_schedule(const GradientAutoschedulerParams ¶ms, const Target &target, Func func, int update_id, @@ -606,10 +612,6 @@ void apply_schedule(const MachineParams ¶ms, for (const ReductionVariable &r : reduction_vars) { rvars.emplace_back(r.var); } - int rdomain_size = 1; - for (int b : rvar_bounds) { - rdomain_size *= b; - } // Define the thresholds for the pure domain. // For CPU we want at least params.parallelism number of elements // to launch threads. For GPU we want to launch at least 64 GPU blocks. @@ -821,7 +823,7 @@ void apply_schedule(const MachineParams ¶ms, void generate_schedule(const std::vector &outputs, const Target &target, - const MachineParams ¶ms, + const GradientAutoschedulerParams ¶ms, AutoSchedulerResults *auto_scheduler_results) { // The first few steps are the same as src/AutoSchedule.cpp // Make an environment map which is used throughout the auto scheduling process. @@ -923,19 +925,42 @@ void generate_schedule(const std::vector &outputs, } } +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API auto_scheduler_results->scheduler_name = "Li2018"; +#endif auto_scheduler_results->schedule_source = schedule_source.str(); debug(1) << schedule_source.str() << "\n"; } struct Li2018 { - void operator()(const Pipeline &p, const Target &target, const MachineParams ¶ms, AutoSchedulerResults *results) { +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + void operator()(const Pipeline &p, const Target &target, const MachineParams ¶ms_in, AutoSchedulerResults *results) { std::vector outputs; for (const Func &f : p.outputs()) { outputs.push_back(f.function()); } + GradientAutoschedulerParams params; + params.parallelism = params_in.parallelism; + generate_schedule(outputs, target, params, results); + } +#else + void operator()(const Pipeline &p, const Target &target, const AutoschedulerParams ¶ms_in, AutoSchedulerResults *results) { + internal_assert(params_in.name == "Li2018"); + + std::vector outputs; + for (const Func &f : p.outputs()) { + outputs.push_back(f.function()); + } + GradientAutoschedulerParams params; + { + ParamParser parser(params_in.extra); + parser.parse("parallelism", ¶ms.parallelism); + parser.finish(); + } generate_schedule(outputs, target, params, results); + results->autoscheduler_params = params_in; } +#endif }; REGISTER_AUTOSCHEDULER(Li2018) diff --git a/src/autoschedulers/li2018/Makefile b/src/autoschedulers/li2018/Makefile index db11cd90768e..715477190dbf 100644 --- a/src/autoschedulers/li2018/Makefile +++ b/src/autoschedulers/li2018/Makefile @@ -18,12 +18,15 @@ else HALIDE_RPATH_FOR_LIB += '-Wl,-rpath,$$ORIGIN' endif -$(BIN)/libautoschedule_li2018.$(SHARED_EXT): $(SRC)/GradientAutoscheduler.cpp $(LIB_HALIDE) +# Be sure *not* to include libHalide in the link steps here; that can cause misbehavior +# on OSX systems in certain situations -- note that $(LIB_HALIDE) is an order-only dep, +# to ensure that (eg) Halide.h is built before this. +$(BIN)/libautoschedule_li2018.$(PLUGIN_EXT): $(SRC)/GradientAutoscheduler.cpp | $(LIB_HALIDE) @mkdir -p $(@D) $(CXX) -shared $(USE_EXPORT_DYNAMIC) -fPIC -fvisibility=hidden -fvisibility-inlines-hidden $(CXXFLAGS) $(OPTIMIZE) $^ -o $@ $(HALIDE_SYSTEM_LIBS) $(HALIDE_RPATH_FOR_LIB) # Demonstrate a JIT-based use of gradient autoscheuler -$(BIN)/test: $(SRC)/test.cpp $(BIN)/libautoschedule_li2018.$(SHARED_EXT) +$(BIN)/test: $(SRC)/test.cpp $(BIN)/libautoschedule_li2018.$(PLUGIN_EXT) @mkdir -p $(@D) $(CXX) $(CXXFLAGS) $(USE_EXPORT_DYNAMIC) $(SRC)/test.cpp -o $@ $(LIBHALIDE_LDFLAGS) $(HALIDE_SYSTEM_LIBS) @@ -33,31 +36,26 @@ $(GENERATOR_BIN)/demo.generator: $(SRC)/demo_generator.cpp $(GENERATOR_DEPS) $(CXX) $(CXXFLAGS) $(USE_EXPORT_DYNAMIC) -g $(filter-out %.h,$^) -o $@ $(LIBHALIDE_LDFLAGS) $(HALIDE_SYSTEM_LIBS) # Use the -p flag to the generator to load the autoscheduler as a plugin -$(BIN)/%/demo.a: $(GENERATOR_BIN)/demo.generator $(BIN)/libautoschedule_li2018.$(SHARED_EXT) +$(BIN)/%/demo.a: $(GENERATOR_BIN)/demo.generator $(BIN)/libautoschedule_li2018.$(PLUGIN_EXT) @mkdir -p $(@D) - $(GENERATOR_BIN)/demo.generator -g demo -o $(@D) -f demo target=$* auto_schedule=true -p $(BIN)/libautoschedule_li2018.$(SHARED_EXT) -s Li2018 + $(GENERATOR_BIN)/demo.generator -g demo -o $(@D) -f demo target=$* autoscheduler=Li2018 -p $(BIN)/libautoschedule_li2018.$(PLUGIN_EXT) $(BIN)/%/demo.rungen: $(BIN)/%/RunGenMain.o $(BIN)/%/demo.registration.cpp $(BIN)/%/demo.a @mkdir -p $(@D) $(CXX) $(CXXFLAGS) -I$(BIN)/$* $^ -o $@ $(HALIDE_SYSTEM_LIBS) $(IMAGE_IO_FLAGS) -.PHONY: build test clean run_test_cpp run_test_py test_generator +.PHONY: build test clean run_test_cpp test_generator # demonstrates single-shot use of the autoscheduler -test_generator: $(BIN)/$(HL_TARGET)/demo.rungen $(BIN)/libautoschedule_li2018.$(SHARED_EXT) +test_generator: $(BIN)/$(HL_TARGET)/demo.rungen $(BIN)/libautoschedule_li2018.$(PLUGIN_EXT) $< --benchmarks=all --benchmark_min_time=1 --estimate_all run_test_cpp: $(BIN)/test - LD_LIBRARY_PATH=$(BIN) $< $(BIN)/libautoschedule_li2018.$(SHARED_EXT) + LD_LIBRARY_PATH=$(BIN) $< $(BIN)/libautoschedule_li2018.$(PLUGIN_EXT) -run_test_py: $(SRC)/test.py $(BIN)/libautoschedule_li2018.$(SHARED_EXT) - PYTHONPATH=$(BIN):$(HALIDE_PYTHON_BINDINGS_PATH):$(HALIDE_DISTRIB_PATH)/bin:$$PYTHONPATH \ - LD_LIBRARY_PATH=$(BIN):$(HALIDE_PYTHON_BINDINGS_PATH):$(HALIDE_DISTRIB_PATH)/bin \ - $(PYTHON) $(SRC)/test.py +build: $(BIN)/test $(BIN)/$(HL_TARGET)/demo.rungen $(BIN)/libautoschedule_li2018.$(PLUGIN_EXT) -\build: $(BIN)/test $(BIN)/$(HL_TARGET)/demo.rungen $(BIN)/libautoschedule_li2018.$(SHARED_EXT) - -test: run_test_cpp run_test_py test_generator +test: run_test_cpp test_generator clean: rm -rf $(BIN) diff --git a/src/autoschedulers/li2018/test.cpp b/src/autoschedulers/li2018/test.cpp index 6518cda38960..f3fb11f7cca7 100644 --- a/src/autoschedulers/li2018/test.cpp +++ b/src/autoschedulers/li2018/test.cpp @@ -10,7 +10,13 @@ int main(int argc, char **argv) { load_plugin(argv[1]); - MachineParams params(32, 16000000, 40); + constexpr int parallelism = 32; +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + MachineParams params(parallelism, 16000000, 40); +#else + AutoschedulerParams params = {"Li2018", {{"parallelism", std::to_string(parallelism)}}}; +#endif + Target target; Var x("x"), y("y"); @@ -27,8 +33,11 @@ int main(int argc, char **argv) { f2.set_estimate(x, 0, 10000); - AutoSchedulerResults result = - Pipeline(f2).auto_schedule(target, params); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + AutoSchedulerResults result = Pipeline(f2).auto_schedule(target, params); +#else + AutoSchedulerResults result = Pipeline(f2).apply_autoscheduler(target, params); +#endif std::cout << "Schedule for 1D pointwise operations:\n" << result.schedule_source << "\n\n"; } @@ -46,8 +55,11 @@ int main(int argc, char **argv) { f2.set_estimate(x, 0, 1000) .set_estimate(y, 0, 1000); - AutoSchedulerResults result = - Pipeline(f2).auto_schedule(target, params); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + AutoSchedulerResults result = Pipeline(f2).auto_schedule(target, params); +#else + AutoSchedulerResults result = Pipeline(f2).apply_autoscheduler(target, params); +#endif std::cout << "Schedule for 2D pointwise operations:\n" << result.schedule_source << "\n\n"; } @@ -61,8 +73,11 @@ int main(int argc, char **argv) { f0.set_estimate(x, 0, 1000); - AutoSchedulerResults result = - Pipeline(f0).auto_schedule(target, params); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + AutoSchedulerResults result = Pipeline(f0).auto_schedule(target, params); +#else + AutoSchedulerResults result = Pipeline(f0).apply_autoscheduler(target, params); +#endif std::cout << "Schedule for 1D convolution:\n" << result.schedule_source << "\n\n"; } @@ -77,8 +92,11 @@ int main(int argc, char **argv) { f0.set_estimate(x, 0, 1000) .set_estimate(y, 0, 1000); - AutoSchedulerResults result = - Pipeline(f0).auto_schedule(target, params); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + AutoSchedulerResults result = Pipeline(f0).auto_schedule(target, params); +#else + AutoSchedulerResults result = Pipeline(f0).apply_autoscheduler(target, params); +#endif std::cout << "Schedule for 2D convolution:\n" << result.schedule_source << "\n\n"; } @@ -93,8 +111,11 @@ int main(int argc, char **argv) { hist.set_estimate(x, 0, 10); - AutoSchedulerResults result = - Pipeline(hist).auto_schedule(target, params); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + AutoSchedulerResults result = Pipeline(hist).auto_schedule(target, params); +#else + AutoSchedulerResults result = Pipeline(hist).apply_autoscheduler(target, params); +#endif std::cout << "Schedule for 1D histogram:\n" << result.schedule_source << "\n\n"; } @@ -109,8 +130,11 @@ int main(int argc, char **argv) { hist.set_estimate(x, 0, 10); - AutoSchedulerResults result = - Pipeline(hist).auto_schedule(target, params); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + AutoSchedulerResults result = Pipeline(hist).auto_schedule(target, params); +#else + AutoSchedulerResults result = Pipeline(hist).apply_autoscheduler(target, params); +#endif std::cout << "Schedule for 2D histogram:\n" << result.schedule_source << "\n\n"; } @@ -125,8 +149,11 @@ int main(int argc, char **argv) { hist.set_estimate(x, 0, 10000); - AutoSchedulerResults result = - Pipeline(hist).auto_schedule(target, params); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + AutoSchedulerResults result = Pipeline(hist).auto_schedule(target, params); +#else + AutoSchedulerResults result = Pipeline(hist).apply_autoscheduler(target, params); +#endif std::cout << "Schedule for 2D histogram with larger domain:\n" << result.schedule_source << "\n\n"; } @@ -146,8 +173,11 @@ int main(int argc, char **argv) { f2.set_estimate(y, 0, 1024) .set_estimate(x, 0, 4); - AutoSchedulerResults result = - Pipeline(f2).auto_schedule(target, params); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + AutoSchedulerResults result = Pipeline(f2).auto_schedule(target, params); +#else + AutoSchedulerResults result = Pipeline(f2).apply_autoscheduler(target, params); +#endif std::cout << "Schedule for 2D pointwise operations with small x dimension:\n" << result.schedule_source << "\n\n"; } diff --git a/src/autoschedulers/li2018/test.py b/src/autoschedulers/li2018/test.py index 30415ee9c5f6..c21003203c18 100644 --- a/src/autoschedulers/li2018/test.py +++ b/src/autoschedulers/li2018/test.py @@ -17,9 +17,8 @@ def main(): f_2.set_estimate(x, 0, 1000) p = hl.Pipeline(f_2) target = hl.Target() - # Only first parameter is used (number of cores on CPU) - params = hl.MachineParams(32, 0, 0); - result = p.auto_schedule('Li2018', target, params) + asp = hl.AutoschedulerParams('Li2018', {'parallelism': 32}) + result = p.apply_autoscheduler(target, asp) print('Schedule:') print(result.schedule_source) print('Python Schedule:') diff --git a/src/autoschedulers/mcts/AutoSchedule.cpp b/src/autoschedulers/mcts/AutoSchedule.cpp index 46e4c8ebc6cb..3588edefb729 100644 --- a/src/autoschedulers/mcts/AutoSchedule.cpp +++ b/src/autoschedulers/mcts/AutoSchedule.cpp @@ -31,7 +31,7 @@ Write out a training featurization for the selected schedule into this file. Needs to be converted to a sample file with the runtime using featurization_to_sample before it can be used to train. - HL_MACHINE_PARAMS + MctsParams An architecture description string. Used by Halide master to configure the cost model. We only use the first term. Set it to the number of cores to target. HL_PERMIT_FAILED_UNROLL @@ -39,7 +39,7 @@ HL_SCHEDULE_FILE *** DEPRECATED *** use the 'schedule' output from Generator instead - Write out a human-and-machine readable block of scheduling source code for the selected schedule into this file. + Write out a human-and-machine MctsParams block of scheduling source code for the selected schedule into this file. HL_RANDOM_DROPOUT percent chance of accepting each state in the beam. Normalized by the number of decisions made, so 5 would be there's a 5 percent chance of never rejecting any states. @@ -88,11 +88,12 @@ #include "LoopNest.h" #include "NetworkSize.h" #include "PerfectHashMap.h" +#include "ParamParser.h" -#include "MCTS.h" #include "CPU_State.h" -#include "Timer.h" #include "CostPrinter.h" +#include "MCTS.h" +#include "Timer.h" #ifdef _WIN32 #include @@ -100,104 +101,104 @@ #endif namespace MCTS { - uint32_t get_dropout_threshold() { - std::string random_dropout_str = Halide::Internal::get_env_variable("HL_RANDOM_DROPOUT"); - if (!random_dropout_str.empty()) { - return atoi(random_dropout_str.c_str()); - } else { - return 100; - } +uint32_t get_dropout_threshold() { + std::string random_dropout_str = Halide::Internal::get_env_variable("HL_RANDOM_DROPOUT"); + if (!random_dropout_str.empty()) { + return atoi(random_dropout_str.c_str()); + } else { + return 100; } +} - bool random_dropout(std::mt19937 &rng, size_t num_decisions) { - static double random_dropout_threshold = get_dropout_threshold(); - if (random_dropout_threshold >= 100) { - return false; - } +bool random_dropout(std::mt19937 &rng, size_t num_decisions) { + static double random_dropout_threshold = get_dropout_threshold(); + if (random_dropout_threshold >= 100) { + return false; + } - // The random dropout threshold is the chance that we operate - // entirely greedily and never discard anything. - double t = random_dropout_threshold; - t /= 100; - t = std::pow(t, 1.0f / num_decisions); - t *= 100; + // The random dropout threshold is the chance that we operate + // entirely greedily and never discard anything. + double t = random_dropout_threshold; + t /= 100; + t = std::pow(t, 1.0f / num_decisions); + t *= 100; - uint32_t r = rng(); - bool drop_it = (r % 100) >= t; - return drop_it; - } + uint32_t r = rng(); + bool drop_it = (r % 100) >= t; + return drop_it; +} - double get_exploration_percent() { - std::string exploration_str = Halide::Internal::get_env_variable("HL_MCTS_EXPLORATION"); - if (!exploration_str.empty()) { - return std::stod(exploration_str.c_str()); - } else { - return .025; - } +double get_exploration_percent() { + std::string exploration_str = Halide::Internal::get_env_variable("HL_MCTS_EXPLORATION"); + if (!exploration_str.empty()) { + return std::stod(exploration_str.c_str()); + } else { + return .025; } +} - double get_exploitation_percent() { - std::string exploitation_str = Halide::Internal::get_env_variable("HL_MCTS_EXPLOITATION"); - if (!exploitation_str.empty()) { - return std::stod(exploitation_str.c_str()); - } else { - return .025; - } +double get_exploitation_percent() { + std::string exploitation_str = Halide::Internal::get_env_variable("HL_MCTS_EXPLOITATION"); + if (!exploitation_str.empty()) { + return std::stod(exploitation_str.c_str()); + } else { + return .025; } +} - uint32_t get_min_explore() { - std::string min_iters_str = Halide::Internal::get_env_variable("HL_MCTS_EXPLORE_MIN"); - if (!min_iters_str.empty()) { - return atoi(min_iters_str.c_str()); - } else { - return 4; - } +uint32_t get_min_explore() { + std::string min_iters_str = Halide::Internal::get_env_variable("HL_MCTS_EXPLORE_MIN"); + if (!min_iters_str.empty()) { + return atoi(min_iters_str.c_str()); + } else { + return 4; } +} - uint32_t get_min_exploit() { - std::string min_iters_str = Halide::Internal::get_env_variable("HL_MCTS_EXPLOIT_MIN"); - if (!min_iters_str.empty()) { - return atoi(min_iters_str.c_str()); - } else { - return 4; - } +uint32_t get_min_exploit() { + std::string min_iters_str = Halide::Internal::get_env_variable("HL_MCTS_EXPLOIT_MIN"); + if (!min_iters_str.empty()) { + return atoi(min_iters_str.c_str()); + } else { + return 4; } +} - uint32_t get_rollout_length() { - std::string rollout_str = Halide::Internal::get_env_variable("HL_MCTS_ROLLOUT_LENGTH"); - if (!rollout_str.empty()) { - return atoi(rollout_str.c_str()); - } else { - return 4; - } +uint32_t get_rollout_length() { + std::string rollout_str = Halide::Internal::get_env_variable("HL_MCTS_ROLLOUT_LENGTH"); + if (!rollout_str.empty()) { + return atoi(rollout_str.c_str()); + } else { + return 4; } +} - uint32_t get_beam_size() { - std::string beam_str = Halide::Internal::get_env_variable("HL_MCTS_BEAM_SIZE"); - if (!beam_str.empty()) { - return atoi(beam_str.c_str()); - } else { - return 4; - } +uint32_t get_beam_size() { + std::string beam_str = Halide::Internal::get_env_variable("HL_MCTS_BEAM_SIZE"); + if (!beam_str.empty()) { + return atoi(beam_str.c_str()); + } else { + return 4; } +} - bool use_beam() { - std::string beam_str = Halide::Internal::get_env_variable("HL_MCTS_DISABLE_BEAM"); - return beam_str != "1"; - } +bool use_beam() { + std::string beam_str = Halide::Internal::get_env_variable("HL_MCTS_DISABLE_BEAM"); + return beam_str != "1"; +} - void print_env_variables() { - // TODO: add to this if we add to the variables above - std::cerr << "export HL_RANDOM_DROPOUT=" << get_dropout_threshold() << "; "; - std::cerr << "export HL_MCTS_EXPLORATION=" << get_exploration_percent() << "; "; - std::cerr << "export HL_MCTS_EXPLOITATION=" << get_exploitation_percent() << "; "; - std::cerr << "export HL_MCTS_EXPLORE_MIN=" << get_min_explore() << "; "; - std::cerr << "export HL_MCTS_EXPLOIT_MIN=" << get_min_exploit() << "; "; - std::cerr << "export HL_MCTS_ROLLOUT_LENGTH=" << get_rollout_length() << "; "; - std::cerr << "export HL_MCTS_BEAM_SIZE=" << get_beam_size() << "; "; - std::cerr << "export HL_MCTS_DISABLE_BEAM=" << !use_beam() << ";\n"; - } +void print_env_variables() { + // TODO: add to this if we add to the variables above + std::cerr << "export HL_RANDOM_DROPOUT=" << get_dropout_threshold() << "; "; + std::cerr << "export HL_MCTS_EXPLORATION=" << get_exploration_percent() << "; "; + std::cerr << "export HL_MCTS_EXPLOITATION=" << get_exploitation_percent() << "; "; + std::cerr << "export HL_MCTS_EXPLORE_MIN=" << get_min_explore() << "; "; + std::cerr << "export HL_MCTS_EXPLOIT_MIN=" << get_min_exploit() << "; "; + std::cerr << "export HL_MCTS_ROLLOUT_LENGTH=" << get_rollout_length() << "; "; + std::cerr << "export HL_MCTS_BEAM_SIZE=" << get_beam_size() << "; "; + std::cerr << "export HL_MCTS_DISABLE_BEAM=" << !use_beam() << ";\n"; } +} // namespace MCTS namespace Halide { namespace Internal { @@ -253,17 +254,16 @@ struct ProgressBar { // Configure a cost model to process a specific pipeline. void configure_pipeline_features(const FunctionDAG &dag, - const MachineParams ¶ms, + const MctsParams ¶ms, CostModel *cost_model) { cost_model->reset(); cost_model->set_pipeline_features(dag, params); } - // The main entrypoint to generate a schedule for a pipeline. void generate_schedule(const std::vector &outputs, const Target &target, - const MachineParams ¶ms, + const MctsParams ¶ms, AutoSchedulerResults *auto_scheduler_results) { aslog(0) << "generate_schedule for target=" << target.to_string() << "\n"; @@ -290,7 +290,7 @@ void generate_schedule(const std::vector &outputs, int64_t memory_limit = memory_limit_str.empty() ? (uint64_t)(-1) : std::atoll(memory_limit_str.c_str()); // Analyse the Halide algorithm and construct our abstract representation of it - FunctionDAG dag(outputs, params, target); + FunctionDAG dag(outputs, target); if (aslog::aslog_level() > 0) { dag.dump(); } @@ -315,15 +315,15 @@ void generate_schedule(const std::vector &outputs, CPU_State start_state(&dag, ¶ms, cost_model.get(), root, /* n_decisions */ 0, memory_limit); aslog(0) << "Starting\n"; MCTS::state_count = 0; - std::shared_ptr > best_action = nullptr; + std::shared_ptr> best_action = nullptr; double cost = 0.0f; std::string schedule_source; std::string python_schedule_source; try { - CPU_State optimal = (MCTS::use_beam()) ? - solver.solve_beam(start_state, /* n_decisions*/ dag.nodes.size() * 2, seed) - : solver.solve(start_state, /* n_decisions*/ dag.nodes.size() * 2, seed); + CPU_State optimal = (MCTS::use_beam()) ? + solver.solve_beam(start_state, /* n_decisions*/ dag.nodes.size() * 2, seed) : + solver.solve(start_state, /* n_decisions*/ dag.nodes.size() * 2, seed); cost = optimal.calculate_cost(); schedule_source = optimal.apply_schedule(python_schedule_source); std::cerr << "is_terminal? " << optimal.is_terminal() << std::endl; @@ -347,7 +347,7 @@ void generate_schedule(const std::vector &outputs, } if (auto_scheduler_results) { - auto_scheduler_results->scheduler_name = "mcts"; + auto_scheduler_results->autoscheduler_params.name = "mcts"; auto_scheduler_results->schedule_source = schedule_source; auto_scheduler_results->python_schedule_source = python_schedule_source; { @@ -358,7 +358,7 @@ void generate_schedule(const std::vector &outputs, } } - } catch (const std::bad_alloc& e) { + } catch (const std::bad_alloc &e) { std::cerr << "Allocation failed: " << e.what() << std::endl; } catch (...) { std::cerr << "Some other exception?" << std::endl; @@ -367,7 +367,7 @@ void generate_schedule(const std::vector &outputs, std::chrono::duration total_time = timer.elapsed(); auto milli = std::chrono::duration_cast(total_time).count(); - //aslog(0) << "Found Pipeline with cost: " << cost << "\n"; + // aslog(0) << "Found Pipeline with cost: " << cost << "\n"; aslog(0) << "Best cost: " << cost << "\n"; aslog(0) << "Execution time: " << milli << " ms\n\n"; @@ -381,9 +381,9 @@ void generate_schedule(const std::vector &outputs, user_warning << "HL_SCHEDULE_FILE is deprecated; use the schedule output from Generator instead\n"; aslog(1) << "Writing schedule to " << schedule_file << "...\n"; std::ofstream f(schedule_file); - f << "// --- BEGIN machine-generated schedule\n" + f << "// --- BEGIN machine-MctsParams schedule\n" << schedule_source - << "// --- END machine-generated schedule\n"; + << "// --- END machine-MctsParams schedule\n"; f.close(); internal_assert(!f.fail()) << "Failed to write " << schedule_file; } @@ -393,22 +393,29 @@ void generate_schedule(const std::vector &outputs, user_warning << "HL_PYTHON_SCHEDULE_FILE is deprecated; use the schedule output from Generator instead\n"; aslog(1) << "Writing schedule to " << python_schedule_file << "...\n"; std::ofstream f(python_schedule_file); - f << "# --- BEGIN machine-generated schedule\n" + f << "# --- BEGIN machine-MctsParams schedule\n" << python_schedule_source - << "# --- END machine-generated schedule\n"; + << "# --- END machine-MctsParams schedule\n"; f.close(); internal_assert(!f.fail()) << "Failed to write " << python_schedule_file; } - } struct mcts { - void operator()(const Pipeline &p, const Target &target, const MachineParams ¶ms, AutoSchedulerResults *results) { + void operator()(const Pipeline &p, const Target &target, const AutoschedulerParams ¶ms_in, AutoSchedulerResults *results) { std::vector outputs; for (const Func &f : p.outputs()) { outputs.push_back(f.function()); } - Autoscheduler::generate_schedule(outputs, target, params, results); + auto params = MctsParams::generic(); + { + ParamParser parser(params_in.extra); + parser.parse("parallelism", ¶ms.parallelism); + parser.parse("last_level_cache_size", ¶ms.last_level_cache_size); + parser.parse("balance", ¶ms.balance); + parser.finish(); + } + Autoscheduler::generate_schedule(outputs, target, MctsParams::generic(), results); } }; @@ -417,7 +424,7 @@ REGISTER_AUTOSCHEDULER(mcts) // An alternative entrypoint for other uses // void find_and_apply_schedule(FunctionDAG &dag, // const std::vector &outputs, -// const MachineParams ¶ms, +// const MctsParams ¶ms, // CostModel *cost_model, // int beam_size, // int64_t memory_limit, diff --git a/src/autoschedulers/mcts/AutoSchedule.h b/src/autoschedulers/mcts/AutoSchedule.h index b7a76dc67e50..88f3b1579e02 100644 --- a/src/autoschedulers/mcts/AutoSchedule.h +++ b/src/autoschedulers/mcts/AutoSchedule.h @@ -11,7 +11,7 @@ namespace Autoscheduler { typedef PerfectHashMap StageMapOfScheduleFeatures; -void find_and_apply_schedule(FunctionDAG &dag, const std::vector &outputs, const MachineParams ¶ms, +void find_and_apply_schedule(FunctionDAG &dag, const std::vector &outputs, const MctsParams ¶ms, CostModel *cost_model, int beam_size, StageMapOfScheduleFeatures *schedule_features); } // namespace Autoscheduler diff --git a/src/autoschedulers/mcts/CPU_State.cpp b/src/autoschedulers/mcts/CPU_State.cpp index 56b25fef73e4..9c4d0527eaac 100644 --- a/src/autoschedulers/mcts/CPU_State.cpp +++ b/src/autoschedulers/mcts/CPU_State.cpp @@ -10,7 +10,7 @@ namespace Halide { namespace Internal { namespace Autoscheduler { -// void compute_featurization(const FunctionDAG *dag, const MachineParams *params, StageMap *features) { +// void compute_featurization(const FunctionDAG *dag, const MctsParams *params, StageMap *features) { // } @@ -682,7 +682,7 @@ void CPU_State::dump() const { } // This code is taken from the State::calculate_cost() code in the adams2019 autoscheduler. -bool prunable(const FunctionDAG *dag_ptr, const MachineParams *params_ptr, const LoopNest *root_ptr, StageMap &features, int64_t memory_limit) { +bool prunable(const FunctionDAG *dag_ptr, const MctsParams *params_ptr, const LoopNest *root_ptr, StageMap &features, int64_t memory_limit) { compute_featurization(dag_ptr, params_ptr, root_ptr, &features); // TODO(rootjalex): add a verbose dump @@ -721,7 +721,7 @@ bool prunable(const FunctionDAG *dag_ptr, const MachineParams *params_ptr, const } // This is directly taken from State::compute_featurization. -void compute_featurization(const FunctionDAG *dag_ptr, const MachineParams *params_ptr, const LoopNest *root_ptr, StageMap *features) { +void compute_featurization(const FunctionDAG *dag_ptr, const MctsParams *params_ptr, const LoopNest *root_ptr, StageMap *features) { StageMap sites; sites.make_large(dag_ptr->nodes[0].stages[0].max_id); features->make_large(dag_ptr->nodes[0].stages[0].max_id); @@ -793,7 +793,7 @@ void compute_featurization(const FunctionDAG *dag_ptr, const MachineParams *para } // This is directly taken from State::save_featurization. -void save_featurization(const FunctionDAG *dag_ptr, const MachineParams *params_ptr, const LoopNest *root_ptr, std::ostream &out) { +void save_featurization(const FunctionDAG *dag_ptr, const MctsParams *params_ptr, const LoopNest *root_ptr, std::ostream &out) { StageMap features; compute_featurization(dag_ptr, params_ptr, root_ptr, &features); diff --git a/src/autoschedulers/mcts/CPU_State.h b/src/autoschedulers/mcts/CPU_State.h index ca415899ed95..33702f72700b 100644 --- a/src/autoschedulers/mcts/CPU_State.h +++ b/src/autoschedulers/mcts/CPU_State.h @@ -123,7 +123,7 @@ class CPU_State { // TODO(rootjalex): should these be static members then? public: const FunctionDAG *dag_ptr; - const MachineParams *params_ptr; + const MctsParams *params_ptr; CostModel *model_ptr; int64_t memory_limit = 0; private: @@ -146,7 +146,7 @@ class CPU_State { CPU_State(const CPU_State &_state) = default; CPU_State(CPU_State &&_state) = default; CPU_State &operator=(const CPU_State &_state) = default; - CPU_State(const FunctionDAG *_dag_ptr, const MachineParams *_params_ptr, + CPU_State(const FunctionDAG *_dag_ptr, const MctsParams *_params_ptr, CostModel *_model_ptr, IntrusivePtr _root, int n_decisions, int64_t _memory_limit = 0) : root(_root), n_decisions_made(n_decisions), dag_ptr(_dag_ptr), params_ptr(_params_ptr), model_ptr(_model_ptr), memory_limit(_memory_limit) { @@ -204,13 +204,13 @@ class CPU_State { // This is used to early-out for certain prunable States. // Returns true if this LoopNest should not be a valid State. -bool prunable(const FunctionDAG *dag_ptr, const MachineParams *params_ptr, const LoopNest *root_ptr, StageMap &features, int64_t memory_limit); +bool prunable(const FunctionDAG *dag_ptr, const MctsParams *params_ptr, const LoopNest *root_ptr, StageMap &features, int64_t memory_limit); // Used by the above to check if a LoopNest is prunable. -void compute_featurization(const FunctionDAG *dag_ptr, const MachineParams *params_ptr, const LoopNest *root_ptr, StageMap *features); +void compute_featurization(const FunctionDAG *dag_ptr, const MctsParams *params_ptr, const LoopNest *root_ptr, StageMap *features); // Calls `compute_featurization` and prints those features to `out`. -void save_featurization(const FunctionDAG *dag_ptr, const MachineParams *params_ptr, const LoopNest *root_ptr, std::ostream &out); +void save_featurization(const FunctionDAG *dag_ptr, const MctsParams *params_ptr, const LoopNest *root_ptr, std::ostream &out); } // namespace Autoscheduler } // namespace Internal diff --git a/src/autoschedulers/mcts/CostModel.h b/src/autoschedulers/mcts/CostModel.h index 8459932c8dca..2833f2e3c546 100644 --- a/src/autoschedulers/mcts/CostModel.h +++ b/src/autoschedulers/mcts/CostModel.h @@ -13,6 +13,49 @@ namespace Halide { namespace Internal { namespace Autoscheduler { typedef PerfectHashMap StageMapOfScheduleFeatures; + +/** A struct representing the machine parameters to generate the auto-scheduled + * code for. */ +struct MctsParams { + /** Maximum level of parallelism avalaible. */ + int parallelism; + /** Size of the last-level cache (in bytes). */ + uint64_t last_level_cache_size; + /** Indicates how much more expensive is the cost of a load compared to + * the cost of an arithmetic operation at last level cache. */ + float balance; + + explicit MctsParams(int parallelism, uint64_t llc, float balance) + : parallelism(parallelism), last_level_cache_size(llc), balance(balance) { + } + + /** Default machine parameters for generic CPU architecture. */ + static MctsParams generic() { + std::string params = Internal::get_env_variable("HL_MACHINE_PARAMS"); + if (params.empty()) { + return MctsParams(16, 16 * 1024 * 1024, 40); + } else { + return MctsParams(params); + } + } + + /** Convert the MctsParams into canonical string form. */ + std::string to_string() const { + std::ostringstream o; + o << parallelism << "," << last_level_cache_size << "," << balance; + return o.str(); + } + + /** Reconstruct a MctsParams from canonical string form. */ + explicit MctsParams(const std::string &s) { + std::vector v = Internal::split_string(s, ","); + user_assert(v.size() == 3) << "Unable to parse MctsParams: " << s; + parallelism = std::atoi(v[0].c_str()); + last_level_cache_size = std::atoll(v[1].c_str()); + balance = std::atof(v[2].c_str()); + } +}; + } // namespace Autoscheduler } // namespace Internal @@ -22,7 +65,7 @@ class CostModel { // Configure the cost model for the algorithm to be scheduled. virtual void set_pipeline_features(const Internal::Autoscheduler::FunctionDAG &dag, - const MachineParams ¶ms) = 0; + const Internal::Autoscheduler::MctsParams ¶ms) = 0; // Enqueue a schedule to be evaluated. Will annotate the value located at cost_ptr when the evaluation takes place. // Note that the dag argument should correspond to the dag specified previously when calling set_pipeline_features. diff --git a/src/autoschedulers/mcts/DefaultCostModel.cpp b/src/autoschedulers/mcts/DefaultCostModel.cpp index 1fc011c8b3fa..6f8889fb46ed 100644 --- a/src/autoschedulers/mcts/DefaultCostModel.cpp +++ b/src/autoschedulers/mcts/DefaultCostModel.cpp @@ -47,7 +47,7 @@ bool ends_with(const std::string &str, const std::string &suffix) { } // namespace void DefaultCostModel::set_pipeline_features(const Internal::Autoscheduler::FunctionDAG &dag, - const MachineParams ¶ms) { + const Internal::Autoscheduler::MctsParams ¶ms) { const int pipeline_feat_size = head1_w * head1_h; // We ignore the first seven pipeline features in the cost diff --git a/src/autoschedulers/mcts/DefaultCostModel.h b/src/autoschedulers/mcts/DefaultCostModel.h index 11dff14ef0dc..6e69b054ccd3 100644 --- a/src/autoschedulers/mcts/DefaultCostModel.h +++ b/src/autoschedulers/mcts/DefaultCostModel.h @@ -7,6 +7,12 @@ namespace Halide { +namespace Internal { +namespace Autoscheduler { +struct MctsParams; +} // namespace Autoscheduler +} // namespace Internal + class DefaultCostModel : public CostModel { private: Internal::Weights weights; @@ -37,7 +43,7 @@ class DefaultCostModel : public CostModel { // Configure the cost model for the algorithm to be scheduled. void set_pipeline_features(const Internal::Autoscheduler::FunctionDAG &dag, - const MachineParams ¶ms) override; + const Internal::Autoscheduler::MctsParams ¶ms) override; void set_pipeline_features(const Runtime::Buffer &, int n); // Enqueue a schedule to be evaluated. The second version of this method returns a buffer of diff --git a/src/autoschedulers/mcts/FunctionDAG.cpp b/src/autoschedulers/mcts/FunctionDAG.cpp index 3530cd22f9f9..08f53273c08f 100644 --- a/src/autoschedulers/mcts/FunctionDAG.cpp +++ b/src/autoschedulers/mcts/FunctionDAG.cpp @@ -549,7 +549,7 @@ void FunctionDAG::Edge::expand_footprint(const Span *consumer_loop, Span *produc } } -FunctionDAG::FunctionDAG(const vector &outputs, const MachineParams ¶ms, const Target &target) { +FunctionDAG::FunctionDAG(const vector &outputs, const Target &target) { map env = build_environment(outputs); // A mutator to apply parameter estimates to the expressions diff --git a/src/autoschedulers/mcts/FunctionDAG.h b/src/autoschedulers/mcts/FunctionDAG.h index 492ec3067e50..b693e6df7132 100644 --- a/src/autoschedulers/mcts/FunctionDAG.h +++ b/src/autoschedulers/mcts/FunctionDAG.h @@ -563,7 +563,7 @@ struct FunctionDAG { // Create the function DAG, and do all the dependency and cost // analysis. This is done once up-front before the tree search. - FunctionDAG(const vector &outputs, const MachineParams ¶ms, const Target &target); + FunctionDAG(const vector &outputs, const Target &target); void dump() const; std::ostream &dump(std::ostream &os) const; diff --git a/src/autoschedulers/mcts/LoopNest.cpp b/src/autoschedulers/mcts/LoopNest.cpp index ac23129f9d37..1f8fb6465c8f 100644 --- a/src/autoschedulers/mcts/LoopNest.cpp +++ b/src/autoschedulers/mcts/LoopNest.cpp @@ -228,7 +228,7 @@ void LoopNest::get_sites(StageMap &sites, // Do a recursive walk over the loop nest computing features to feed the cost model. void LoopNest::compute_features(const FunctionDAG &dag, - const MachineParams ¶ms, + const MctsParams ¶ms, const StageMap &sites, int64_t instances, int64_t parallelism, @@ -1288,7 +1288,7 @@ void LoopNest::compute_here(const FunctionDAG::Node *f, bool tileable, int v) { } // Parallelize this loop according to the given tiling. -IntrusivePtr LoopNest::parallelize_in_tiles(const MachineParams ¶ms, +IntrusivePtr LoopNest::parallelize_in_tiles(const MctsParams ¶ms, const vector &tiling, const LoopNest *parent) const { @@ -1356,7 +1356,7 @@ IntrusivePtr LoopNest::parallelize_in_tiles(const MachineParams // this loop nest. vector> LoopNest::compute_in_tiles(const FunctionDAG::Node *f, const LoopNest *parent, - const MachineParams ¶ms, + const MctsParams ¶ms, int v, bool in_realization) const { internal_assert(f); diff --git a/src/autoschedulers/mcts/LoopNest.h b/src/autoschedulers/mcts/LoopNest.h index 78b20c3032f3..f073da89b6c8 100644 --- a/src/autoschedulers/mcts/LoopNest.h +++ b/src/autoschedulers/mcts/LoopNest.h @@ -6,6 +6,7 @@ #ifndef LOOP_NEST_H #define LOOP_NEST_H +#include "CostModel.h" #include "FunctionDAG.h" #include "PerfectHashMap.h" #include @@ -126,7 +127,7 @@ struct LoopNest { // Do a recursive walk over the loop nest computing features to feed the cost model. void compute_features(const FunctionDAG &dag, - const MachineParams ¶ms, + const MctsParams ¶ms, const StageMap &sites, int64_t instances, int64_t parallelism, @@ -185,7 +186,7 @@ struct LoopNest { void compute_here(const FunctionDAG::Node *f, bool tileable, int v); // Parallelize this loop according to the given tiling. - IntrusivePtr parallelize_in_tiles(const MachineParams ¶ms, + IntrusivePtr parallelize_in_tiles(const MctsParams ¶ms, const vector &tiling, const LoopNest *parent) const; @@ -193,7 +194,7 @@ struct LoopNest { // this loop nest. std::vector> compute_in_tiles(const FunctionDAG::Node *f, const LoopNest *parent, - const MachineParams ¶ms, + const MctsParams ¶ms, int v, bool in_realization) const; diff --git a/src/autoschedulers/mcts/cost_model_generator.cpp b/src/autoschedulers/mcts/cost_model_generator.cpp index 536955f6e4b8..ad52c89d7d2b 100644 --- a/src/autoschedulers/mcts/cost_model_generator.cpp +++ b/src/autoschedulers/mcts/cost_model_generator.cpp @@ -123,7 +123,7 @@ class CostModel : public Generator> { using Input = GeneratorInput; template using Output = GeneratorOutput; - using Generator>::auto_schedule; + using Generator>::using_autoscheduler; using Generator>::get_pipeline; // Number of pipeline stages @@ -483,9 +483,9 @@ class CostModel : public Generator> { true_runtime.set_estimates({{0, 80}}); // SCHEDULE - if (training && !auto_schedule) { + if (training && !using_autoscheduler()) { do_cost_model_schedule(get_pipeline()); - } else if (auto_schedule) { + } else if (using_autoscheduler()) { // Do nothing. } else { // We just write down a good schedule for diff --git a/src/autoschedulers/mcts/included_schedule_file_generator.cpp b/src/autoschedulers/mcts/included_schedule_file_generator.cpp index 1a5cb99a784f..7103ec6c80e8 100644 --- a/src/autoschedulers/mcts/included_schedule_file_generator.cpp +++ b/src/autoschedulers/mcts/included_schedule_file_generator.cpp @@ -37,7 +37,7 @@ struct IncludedScheduleFile : public Halide::Generator { relu.set_estimates({{0, CO}, {0, W}, {0, H}, {0, N}}); // Schedule - if (auto_schedule) { + if (using_autoscheduler()) { // nothing } else { #if defined(GENERATING_SCHEDULE) diff --git a/src/autoschedulers/mcts/test.cpp b/src/autoschedulers/mcts/test.cpp index c710871c6185..221743dbb477 100644 --- a/src/autoschedulers/mcts/test.cpp +++ b/src/autoschedulers/mcts/test.cpp @@ -10,7 +10,8 @@ int main(int argc, char **argv) { load_plugin(argv[1]); - MachineParams params(32, 16000000, 40); + AutoschedulerParams params("MCTS", {}); + // Use a fixed target for the analysis to get consistent results from this test. Target target("x86-64-linux-sse41-avx-avx2"); @@ -25,7 +26,7 @@ int main(int argc, char **argv) { h.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); - Pipeline(h).auto_schedule(target, params); + Pipeline(h).apply_autoscheduler(target, params); } if (1) { @@ -45,7 +46,7 @@ int main(int argc, char **argv) { h.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); - Pipeline(h).auto_schedule(target, params); + Pipeline(h).apply_autoscheduler(target, params); } if (1) { @@ -58,7 +59,7 @@ int main(int argc, char **argv) { h.set_estimate(x, 0, 2048).set_estimate(y, 0, 2048); - Pipeline(h).auto_schedule(target, params); + Pipeline(h).apply_autoscheduler(target, params); } // Smaller footprint stencil -> smaller tiles @@ -71,7 +72,7 @@ int main(int argc, char **argv) { h.set_estimate(x, 0, 2048).set_estimate(y, 0, 2048); - Pipeline(h).auto_schedule(target, params); + Pipeline(h).apply_autoscheduler(target, params); } // A stencil chain @@ -90,7 +91,7 @@ int main(int argc, char **argv) { } f[N - 1].set_estimate(x, 0, 2048).set_estimate(y, 0, 2048); - Pipeline(f[N - 1]).auto_schedule(target, params); + Pipeline(f[N - 1]).apply_autoscheduler(target, params); } // An outer product @@ -101,7 +102,7 @@ int main(int argc, char **argv) { f.set_estimate(x, 0, 2048).set_estimate(y, 0, 2048); - Pipeline(f).auto_schedule(target, params); + Pipeline(f).apply_autoscheduler(target, params); } // A separable downsample that models the start of local_laplacian @@ -120,7 +121,7 @@ int main(int argc, char **argv) { downx(x, y, k) = downy(2 * x - 1, y, k) + downy(2 * x, y, k) + downy(2 * x + 1, y, k) + downy(2 * x + 2, y, k); downx.set_estimate(x, 1, 1022).set_estimate(y, 1, 1022).set_estimate(k, 0, 256); - Pipeline(downx).auto_schedule(target, params); + Pipeline(downx).apply_autoscheduler(target, params); } // A Func with multiple stages, some of which include additional loops @@ -139,7 +140,7 @@ int main(int argc, char **argv) { g.set_estimate(x, 1, 1022).set_estimate(y, 1, 1022); - Pipeline(g).auto_schedule(target, params); + Pipeline(g).apply_autoscheduler(target, params); } if (1) { @@ -163,7 +164,7 @@ int main(int argc, char **argv) { after[4].set_estimate(x, 0, 1024).set_estimate(y, 0, 1024); - Pipeline(after[4]).auto_schedule(target, params); + Pipeline(after[4]).apply_autoscheduler(target, params); } if (1) { @@ -183,7 +184,7 @@ int main(int argc, char **argv) { f_u64_2.set_estimate(x, 0, 1024 * 1024); - Pipeline(f_u64_2).auto_schedule(target, params); + Pipeline(f_u64_2).apply_autoscheduler(target, params); } if (1) { @@ -202,7 +203,7 @@ int main(int argc, char **argv) { out.set_estimate(j, 0, 1024).set_estimate(i, 0, 1024); - Pipeline(out).auto_schedule(target, params); + Pipeline(out).apply_autoscheduler(target, params); } if (1) { @@ -232,7 +233,7 @@ int main(int argc, char **argv) { p3[N - 1].set_estimate(x, 0, 1024).set_estimate(y, 0, 1024); - Pipeline(p3[N - 1]).auto_schedule(target, params); + Pipeline(p3[N - 1]).apply_autoscheduler(target, params); } if (1) { @@ -252,7 +253,7 @@ int main(int argc, char **argv) { out.set_estimate(x, 0, 10); - Pipeline(out).auto_schedule(target, params); + Pipeline(out).apply_autoscheduler(target, params); } if (1) { @@ -267,7 +268,7 @@ int main(int argc, char **argv) { g.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); - Pipeline(g).auto_schedule(target, params); + Pipeline(g).apply_autoscheduler(target, params); } if (1) { @@ -282,7 +283,7 @@ int main(int argc, char **argv) { h.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); - Pipeline(h).auto_schedule(target, params); + Pipeline(h).apply_autoscheduler(target, params); } if (1) { @@ -300,7 +301,7 @@ int main(int argc, char **argv) { a.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); b.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); - Pipeline({a, b}).auto_schedule(target, params); + Pipeline({a, b}).apply_autoscheduler(target, params); } if (1) { @@ -311,7 +312,7 @@ int main(int argc, char **argv) { g(x, y) = f(x, y); g.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); - Pipeline(g).auto_schedule(target, params); + Pipeline(g).apply_autoscheduler(target, params); } if (1) { @@ -321,7 +322,7 @@ int main(int argc, char **argv) { f(x, y) = im(x, y) * 7; f.set_estimate(x, 0, 3).set_estimate(y, 0, 5); - Pipeline(f).auto_schedule(target, params); + Pipeline(f).apply_autoscheduler(target, params); } if (1) { @@ -338,7 +339,7 @@ int main(int argc, char **argv) { .set_estimate(t, 0, 3) .set_estimate(u, 0, 2) .set_estimate(v, 0, 6); - Pipeline(f).auto_schedule(target, params); + Pipeline(f).apply_autoscheduler(target, params); } if (1) { @@ -357,7 +358,7 @@ int main(int argc, char **argv) { out1.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); out2.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); - Pipeline({out1, out2}).auto_schedule(target, params); + Pipeline({out1, out2}).apply_autoscheduler(target, params); } if (1) { @@ -382,7 +383,7 @@ int main(int argc, char **argv) { g(x, y) = f[N - 1](x, y) + f[0](clamp(cast(sin(x) * 10000), 0, 100000), clamp(cast(sin(x * y) * 10000), 0, 100000)); g.set_estimate(x, 0, 2048).set_estimate(y, 0, 2048); - Pipeline(g).auto_schedule(target, params); + Pipeline(g).apply_autoscheduler(target, params); } if (1) { @@ -396,7 +397,7 @@ int main(int argc, char **argv) { h(x, y) = g() + x + y; h.set_estimate(x, 0, 1024).set_estimate(y, 0, 2048); - Pipeline(h).auto_schedule(target, params); + Pipeline(h).apply_autoscheduler(target, params); } if (1) { @@ -411,7 +412,7 @@ int main(int argc, char **argv) { g(x, y) = f(x, y); g.set_estimate(x, 0, 10).set_estimate(y, 0, 2048); - Pipeline(g).auto_schedule(target, params); + Pipeline(g).apply_autoscheduler(target, params); } if (1) { @@ -443,7 +444,7 @@ int main(int argc, char **argv) { out(x, y) = up[0](x, y); out.set_estimate(x, 0, 2048).set_estimate(y, 0, 2048); - Pipeline(out).auto_schedule(target, params); + Pipeline(out).apply_autoscheduler(target, params); } if (1) { @@ -461,7 +462,7 @@ int main(int argc, char **argv) { casted(x, y) = scan(x, y); casted.set_estimate(x, 0, 2000).set_estimate(y, 0, 2000); - Pipeline(casted).auto_schedule(target, params); + Pipeline(casted).apply_autoscheduler(target, params); } if (1) { @@ -477,7 +478,7 @@ int main(int argc, char **argv) { f.set_estimate(x, 0, 2000).set_estimate(y, 0, 2000); output.set_estimate(i, 0, 256); - Pipeline(output).auto_schedule(target, params); + Pipeline(output).apply_autoscheduler(target, params); } return 0; diff --git a/src/autoschedulers/mcts/test_function_dag.cpp b/src/autoschedulers/mcts/test_function_dag.cpp index 933c7f9d5027..21702ae42385 100644 --- a/src/autoschedulers/mcts/test_function_dag.cpp +++ b/src/autoschedulers/mcts/test_function_dag.cpp @@ -31,7 +31,7 @@ extern "C" int mul_by_two( return 0; } -void test_coeff_wise(const MachineParams ¶ms, const Target &target) { +void test_coeff_wise(const Target &target) { Var x("x"), y("y"); std::ostringstream with_extern; @@ -55,7 +55,7 @@ void test_coeff_wise(const MachineParams ¶ms, const Target &target) { h.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); std::vector v; v.push_back(h.function()); - Halide::Internal::Autoscheduler::FunctionDAG d(v, params, target); + Halide::Internal::Autoscheduler::FunctionDAG d(v, target); d.dump(with_extern); } @@ -70,7 +70,7 @@ void test_coeff_wise(const MachineParams ¶ms, const Target &target) { h.set_estimate(x, 0, 1000).set_estimate(y, 0, 1000); std::vector v; v.push_back(h.function()); - Halide::Internal::Autoscheduler::FunctionDAG d(v, params, target); + Halide::Internal::Autoscheduler::FunctionDAG d(v, target); d.dump(without_extern); } @@ -113,7 +113,7 @@ extern "C" int matmul( return 0; } -void test_matmul(const MachineParams ¶ms, const Target &target) { +void test_matmul(const Target &target) { Var x("x"), y("y"), k("k"); RDom r(0, 200); Halide::Buffer input1(200, 200); @@ -140,7 +140,7 @@ void test_matmul(const MachineParams ¶ms, const Target &target) { h.set_estimate(x, 0, 200).set_estimate(y, 0, 200); std::vector v; v.push_back(h.function()); - Halide::Internal::Autoscheduler::FunctionDAG d(v, params, target); + Halide::Internal::Autoscheduler::FunctionDAG d(v, target); d.dump(with_extern); } @@ -153,7 +153,7 @@ void test_matmul(const MachineParams ¶ms, const Target &target) { h.set_estimate(x, 0, 200).set_estimate(y, 0, 200); std::vector v; v.push_back(h.function()); - Halide::Internal::Autoscheduler::FunctionDAG d(v, params, target); + Halide::Internal::Autoscheduler::FunctionDAG d(v, target); d.dump(without_extern); } @@ -164,11 +164,10 @@ void test_matmul(const MachineParams ¶ms, const Target &target) { int main(int argc, char **argv) { // Use a fixed target for the analysis to get consistent results from this test. - MachineParams params(32, 16000000, 40); Target target("x86-64-linux-sse41-avx-avx2"); - test_coeff_wise(params, target); - test_matmul(params, target); + test_coeff_wise(target); + test_matmul(target); return 0; } diff --git a/src/autoschedulers/mullapudi2016/AutoSchedule.cpp b/src/autoschedulers/mullapudi2016/AutoSchedule.cpp index 6253b8229c46..82bb5b1a208c 100644 --- a/src/autoschedulers/mullapudi2016/AutoSchedule.cpp +++ b/src/autoschedulers/mullapudi2016/AutoSchedule.cpp @@ -7,9 +7,11 @@ #include #include "Halide.h" +#include "ParamParser.h" namespace Halide { namespace Internal { +namespace Autoscheduler { using std::make_pair; using std::map; @@ -20,6 +22,18 @@ using std::vector; namespace { +struct ArchParams { + /** Maximum level of parallelism avalaible. */ + int parallelism = 16; + + /** Size of the last-level cache (in bytes). */ + uint64_t last_level_cache_size = 16 * 1024 * 1024; + + /** Indicates how much more expensive is the cost of a load compared to + * the cost of an arithmetic operation at last level cache. */ + float balance = 40; +}; + // Substitute parameter estimates into the exprs describing the box bounds. void substitute_estimates_box(Box &box) { box.used = substitute_var_estimates(box.used); @@ -1054,7 +1068,7 @@ struct Partitioner { const map &pipeline_bounds; // Parameters of the machine model that is used for estimating the cost of each // group in the pipeline. - const MachineParams &arch_params; + const ArchParams &arch_params; // Dependency analysis of the pipeline. This support queries on regions // accessed and computed for producing some regions of some functions. DependenceAnalysis &dep_analysis; @@ -1065,7 +1079,7 @@ struct Partitioner { const vector &outputs; Partitioner(const map &_pipeline_bounds, - const MachineParams &_arch_params, + const ArchParams &_arch_params, const vector &_outputs, DependenceAnalysis &_dep_analysis, RegionCosts &_costs); @@ -1305,7 +1319,7 @@ void Partitioner::disp_pipeline_costs() { // Construct a partitioner and build the pipeline graph on which the grouping // algorithm operates. Partitioner::Partitioner(const map &_pipeline_bounds, - const MachineParams &_arch_params, + const ArchParams &_arch_params, const vector &_outputs, DependenceAnalysis &_dep_analysis, RegionCosts &_costs) @@ -3166,7 +3180,7 @@ bool inline_unbounded(const vector &outputs, // outputs. This applies the schedules and returns a string representation of // the schedules. The target architecture is specified by 'target'. string generate_schedules(const vector &outputs, const Target &target, - const MachineParams &arch_params) { + const ArchParams &arch_params) { // Make an environment map which is used throughout the auto scheduling process. map env; for (const Function &f : outputs) { @@ -3372,25 +3386,53 @@ string generate_schedules(const vector &outputs, const Target &target, } struct Mullapudi2016 { - void operator()(const Pipeline &pipeline, const Target &target, const MachineParams &arch_params, AutoSchedulerResults *outputs) { +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + void operator()(const Pipeline &pipeline, const Target &target, const MachineParams ¶ms_in, AutoSchedulerResults *outputs) { AutoSchedulerResults results; results.target = target; - results.machine_params_string = arch_params.to_string(); + results.machine_params_string = params_in.to_string(); results.scheduler_name = "Mullapudi2016"; std::vector pipeline_outputs; for (const Func &f : pipeline.outputs()) { pipeline_outputs.push_back(f.function()); } + ArchParams arch_params{params_in.parallelism, params_in.last_level_cache_size, params_in.balance}; results.schedule_source = generate_schedules(pipeline_outputs, target, arch_params); // this autoscheduler has no featurization + *outputs = std::move(results); + } +#else + void operator()(const Pipeline &pipeline, const Target &target, const AutoschedulerParams ¶ms_in, AutoSchedulerResults *outputs) { + internal_assert(params_in.name == "Mullapudi2016"); - *outputs = results; + AutoSchedulerResults results; + results.target = target; + results.autoscheduler_params = params_in; + + std::vector pipeline_outputs; + for (const Func &f : pipeline.outputs()) { + pipeline_outputs.push_back(f.function()); + } + + ArchParams arch_params; + { + ParamParser parser(params_in.extra); + parser.parse("parallelism", &arch_params.parallelism); + parser.parse("last_level_cache_size", &arch_params.last_level_cache_size); + parser.parse("balance", &arch_params.balance); + parser.finish(); + } + results.schedule_source = generate_schedules(pipeline_outputs, target, arch_params); + results.autoscheduler_params = params_in; + // this autoscheduler has no featurization + *outputs = std::move(results); } +#endif }; REGISTER_AUTOSCHEDULER(Mullapudi2016) +} // namespace Autoscheduler } // namespace Internal - } // namespace Halide diff --git a/src/autoschedulers/mullapudi2016/CMakeLists.txt b/src/autoschedulers/mullapudi2016/CMakeLists.txt index 41a21ab1b086..7b7b3cfa3162 100644 --- a/src/autoschedulers/mullapudi2016/CMakeLists.txt +++ b/src/autoschedulers/mullapudi2016/CMakeLists.txt @@ -1 +1,2 @@ add_autoscheduler(NAME Mullapudi2016 SOURCES AutoSchedule.cpp) +target_link_libraries(Halide_Mullapudi2016 PRIVATE ParamParser) diff --git a/src/autoschedulers/mullapudi2016/Makefile b/src/autoschedulers/mullapudi2016/Makefile index 14eddc0e1128..f79974585a41 100644 --- a/src/autoschedulers/mullapudi2016/Makefile +++ b/src/autoschedulers/mullapudi2016/Makefile @@ -15,6 +15,9 @@ endif CXXFLAGS += -I$(COMMON_DIR) -$(BIN)/libautoschedule_mullapudi2016.$(SHARED_EXT): $(SRC)/AutoSchedule.cpp $(LIB_HALIDE) +# Be sure *not* to include libHalide in the link steps here; that can cause misbehavior +# on OSX systems in certain situations -- note that $(LIB_HALIDE) is an order-only dep, +# to ensure that (eg) Halide.h is built before this. +$(BIN)/libautoschedule_mullapudi2016.$(PLUGIN_EXT): $(SRC)/AutoSchedule.cpp | $(LIB_HALIDE) @mkdir -p $(@D) $(CXX) -shared $(USE_EXPORT_DYNAMIC) -fPIC -fvisibility=hidden -fvisibility-inlines-hidden $(CXXFLAGS) $(OPTIMIZE) $^ -o $@ $(HALIDE_RPATH_FOR_LIB) diff --git a/src/exported_symbols.ldscript b/src/exported_symbols.ldscript index c7002c5c1fdc..2b9f50a68de7 100644 --- a/src/exported_symbols.ldscript +++ b/src/exported_symbols.ldscript @@ -9,6 +9,8 @@ _Z?6Halide* ; _Z??6Halide* ; _Z???6Halide* ; + # non-virtual thunks + _ZThn???_N6Halide* ; local: *; }; diff --git a/src/exported_symbols.osx b/src/exported_symbols.osx index 1d55d665697a..089b3d741970 100644 --- a/src/exported_symbols.osx +++ b/src/exported_symbols.osx @@ -5,3 +5,5 @@ halide_* __Z?6Halide* __Z??6Halide* __Z???6Halide* +# non-virtual thunks +__ZThn???_N6Halide* diff --git a/src/runtime/CMakeLists.txt b/src/runtime/CMakeLists.txt index 34d5a27b181b..946784f662d5 100644 --- a/src/runtime/CMakeLists.txt +++ b/src/runtime/CMakeLists.txt @@ -1,7 +1,3 @@ -if (NOT CMAKE_GENERATOR MATCHES "Make|Ninja") - message(STATUS "Notice: ${CMAKE_GENERATOR} does not support depfile dependencies. Incremental builds may fail.") -endif () - # Keep these lists in alphabetical order. set(RUNTIME_CPP aarch64_cpu_features @@ -22,6 +18,7 @@ set(RUNTIME_CPP fake_get_symbol fake_thread_pool float16_t + force_include_types fuchsia_clock fuchsia_host_cpu_count fuchsia_yield @@ -36,8 +33,6 @@ set(RUNTIME_CPP linux_clock linux_host_cpu_count linux_yield - matlab - metadata metal metal_objc_arm metal_objc_x86 @@ -236,28 +231,7 @@ foreach (i IN LISTS RUNTIME_CPP) set(INITMOD "_initmod_${i}_${j}${SUFFIX}.cpp") set(SYMBOL "halide_internal_initmod_${i}_${j}${SUFFIX}") - set(clang_flags ${RUNTIME_CXX_FLAGS} ${fpic} ${fshort-wchar} ${RUNTIME_DEFINES${SUFFIX}} -m${j} -target ${TARGET} -emit-llvm -S) - - set(ll_path "${LL}") - if (CMAKE_GENERATOR MATCHES "Ninja") - if (POLICY CMP0116) - # CMake 3.20+ does the right thing here and transforms the depfiles for us - list(APPEND clang_flags -MD -MF "${basename}.d") - set(dep_args DEPFILE "${basename}.d") - else() - # Dep-files are subtle and require clang to run using *just* the right - # relative paths to the build root, NOT the Halide build root. This is - # a perfect storm of bad behavior from CMake <3.20, Ninja, and Clang. - file(RELATIVE_PATH ll_path "${CMAKE_BINARY_DIR}" "${CMAKE_CURRENT_BINARY_DIR}/${LL}") - file(TO_NATIVE_PATH "${ll_path}" ll_path) - list(APPEND clang_flags -MD -MF "$") - set(dep_args - WORKING_DIRECTORY "${CMAKE_BINARY_DIR}" - DEPFILE "${CMAKE_CURRENT_BINARY_DIR}/${basename}.d") - endif() - elseif (CMAKE_GENERATOR MATCHES "Make") - set(dep_args IMPLICIT_DEPENDS CXX "${SOURCE}") - endif () + set(clang_flags ${RUNTIME_CXX_FLAGS} ${fpic} ${fshort-wchar} ${RUNTIME_DEFINES${SUFFIX}} -m${j} -target ${TARGET} -emit-llvm -S -MD -MF "${basename}.d") if (Halide_CLANG_TIDY_BUILD) # Create a 'fake' entry just so that clang-tidy will see a C++ compilation command @@ -268,9 +242,9 @@ foreach (i IN LISTS RUNTIME_CPP) target_compile_definitions(${basename} PRIVATE ${RUNTIME_DEFINES}) else() add_custom_command(OUTPUT "${LL}" - COMMAND clang ${clang_flags} -o "${ll_path}" "$" + COMMAND ${CMAKE_C_COMPILER_LAUNCHER} $ ${clang_flags} -o "${LL}" "$" DEPENDS "${SOURCE}" - ${dep_args} + DEPFILE "${basename}.d" VERBATIM) endif() diff --git a/src/runtime/HalideBuffer.h b/src/runtime/HalideBuffer.h index c0191f2af272..1cae329f4a00 100644 --- a/src/runtime/HalideBuffer.h +++ b/src/runtime/HalideBuffer.h @@ -2286,7 +2286,11 @@ class Buffer { template()((const int *)nullptr))> static void for_each_element(int, int dims, const for_each_element_task_dim *t, Fn &&f, int check = 0) { - int *pos = (int *)HALIDE_ALLOCA(dims * sizeof(int)); + const int size = dims * sizeof(int); + int *pos = (int *)HALIDE_ALLOCA(size); + // At least one version of GCC will (incorrectly) report that pos "may be used uninitialized". + // Add this memset to silence it. + memset(pos, 0, size); for_each_element_array(dims - 1, t, std::forward(f), pos); } diff --git a/src/runtime/HalideRuntime.h b/src/runtime/HalideRuntime.h index 6089110420a7..62fc35640eb2 100644 --- a/src/runtime/HalideRuntime.h +++ b/src/runtime/HalideRuntime.h @@ -55,6 +55,24 @@ extern "C" { #endif #endif +// Annotation for AOT and JIT calls -- if undefined, use no annotation. +// To ensure that all results are checked, do something like +// +// -DHALIDE_FUNCTION_ATTRS=HALIDE_MUST_USE_RESULT +// +// in your C++ compiler options +#ifndef HALIDE_FUNCTION_ATTRS +#define HALIDE_FUNCTION_ATTRS +#endif + +#ifndef HALIDE_EXPORT_SYMBOL +#ifdef _MSC_VER +#define HALIDE_EXPORT_SYMBOL __declspec(dllexport) +#else +#define HALIDE_EXPORT_SYMBOL __attribute__((visibility("default"))) +#endif +#endif + /** \file * * This file declares the routines used by Halide internally in its @@ -1071,12 +1089,8 @@ enum halide_error_code_t { * violates a Halide invariant. */ halide_error_code_no_device_interface = -19, - /** An error occurred when attempting to initialize the Matlab - * runtime. */ - halide_error_code_matlab_init_failed = -20, - - /** The type of an mxArray did not match the expected type. */ - halide_error_code_matlab_bad_param_type = -21, + /* unused = -20, */ + /* unused = -21, */ /** There is a bug in the Halide compiler. */ halide_error_code_internal_error = -22, @@ -1295,8 +1309,6 @@ typedef enum halide_target_feature_t { halide_target_feature_user_context, ///< Generated code takes a user_context pointer as first argument - halide_target_feature_matlab, ///< Generate a mexFunction compatible with Matlab mex libraries. See tools/mex_halide.m. - halide_target_feature_profile, ///< Launch a sampling profiler alongside the Halide pipeline that monitors and reports the runtime used by each Func halide_target_feature_no_runtime, ///< Do not include a copy of the Halide runtime in any generated object file or assembly @@ -1332,6 +1344,9 @@ typedef enum halide_target_feature_t { halide_target_feature_hexagon_dma, ///< Enable Hexagon DMA buffers. halide_target_feature_embed_bitcode, ///< Emulate clang -fembed-bitcode flag. halide_target_feature_enable_llvm_loop_opt, ///< Enable loop vectorization + unrolling in LLVM. Overrides halide_target_feature_disable_llvm_loop_opt. (Ignored for non-LLVM targets.) + // halide_target_feature_disable_llvm_loop_opt is deprecated in Halide 15 + // (and will be removed in Halide 16). Halide 15 now defaults to disabling + // LLVM loop optimization, unless halide_target_feature_enable_llvm_loop_opt is set. halide_target_feature_disable_llvm_loop_opt, ///< Disable loop vectorization + unrolling in LLVM. (Ignored for non-LLVM targets.) halide_target_feature_wasm_simd128, ///< Enable +simd128 instructions for WebAssembly codegen. halide_target_feature_wasm_signext, ///< Enable +sign-ext instructions for WebAssembly codegen. @@ -1348,6 +1363,7 @@ typedef enum halide_target_feature_t { halide_target_feature_armv81a, ///< Enable ARMv8.1-a instructions halide_target_feature_sanitizer_coverage, ///< Enable hooks for SanitizerCoverage support. halide_target_feature_profile_by_timer, ///< Alternative to halide_target_feature_profile using timer interrupt for systems without threads or applicartions that need to avoid them. + halide_target_feature_spirv, ///< Enable SPIR-V code generation support. halide_target_feature_end ///< A sentinel. Every target is considered to have this feature, and setting this feature does nothing. } halide_target_feature_t; diff --git a/src/runtime/HalideRuntimeHexagonDma.h b/src/runtime/HalideRuntimeHexagonDma.h index 42b1ea35dc31..3eb9c9c58b82 100644 --- a/src/runtime/HalideRuntimeHexagonDma.h +++ b/src/runtime/HalideRuntimeHexagonDma.h @@ -49,7 +49,7 @@ typedef enum { extern const struct halide_device_interface_t *halide_hexagon_dma_device_interface(); -/** This API is used to set up the DMA device interface to be used for DMA transfer. This also internally +/** This API is used to set up the DMA device interface to be used for DMA transfer. This also internally * creates the DMA device handle and populates all the Buffer related parameters (width, height, stride) * to be used for DMA configuration. */ @@ -66,7 +66,7 @@ extern int halide_hexagon_dma_device_detach_native(void *user_context, struct ha */ extern int halide_hexagon_dma_allocate_engine(void *user_context, void **dma_engine); -/** This API free up the allocated DMA engine. This need to be called after a user program ends +/** This API free up the allocated DMA engine. This need to be called after a user program ends * all the DMA Operations and make it available for subsequent DMA transfers */ extern int halide_hexagon_dma_deallocate_engine(void *user_context, void *dma_engine); @@ -83,7 +83,7 @@ extern int halide_hexagon_dma_prepare_for_copy_to_device(void *user_context, str void *dma_engine, bool is_ubwc, halide_hexagon_image_fmt_t fmt); -/** This API is used to frees up the DMA Resources associated with the buffer. +/** This API is used to frees up the DMA Resources associated with the buffer. * TODO: Currently this API is a dummy as all the necessary freeing is done in an another API. * This will be used in future. */ diff --git a/src/runtime/HalideRuntimeOpenGLCompute.h b/src/runtime/HalideRuntimeOpenGLCompute.h index decca61124f1..1b19472908ab 100644 --- a/src/runtime/HalideRuntimeOpenGLCompute.h +++ b/src/runtime/HalideRuntimeOpenGLCompute.h @@ -59,7 +59,7 @@ void *halide_opengl_get_proc_address(void *user_context, const char *name); /** This function creates an OpenGL context for use by the OpenGL backend. * * You may have to implement this yourself as well. Halide only provides -* implementations for some platforms." + * implementations for some platforms." */ int halide_opengl_create_context(void *user_context); diff --git a/src/runtime/aarch64.ll b/src/runtime/aarch64.ll index a9fcfdc35496..9ae3b8e46ac2 100644 --- a/src/runtime/aarch64.ll +++ b/src/runtime/aarch64.ll @@ -64,7 +64,7 @@ declare <8 x half> @llvm.aarch64.neon.frsqrts.v8f16(<8 x half> %x, <8 x half> %y declare <4 x half> @llvm.aarch64.neon.frsqrts.v4f16(<4 x half> %x, <4 x half> %y) nounwind readnone; define weak_odr float @fast_inverse_f32(float %x) nounwind alwaysinline { - %vec = insertelement <2 x float> undef, float %x, i32 0 + %vec = insertelement <2 x float> poison, float %x, i32 0 %approx = tail call <2 x float> @fast_inverse_f32x2(<2 x float> %vec) %result = extractelement <2 x float> %approx, i32 0 ret float %result @@ -85,7 +85,7 @@ define weak_odr <4 x float> @fast_inverse_f32x4(<4 x float> %x) nounwind alwaysi } define weak_odr half @fast_inverse_f16(half %x) nounwind alwaysinline { - %vec = insertelement <4 x half> undef, half %x, i32 0 + %vec = insertelement <4 x half> poison, half %x, i32 0 %approx = tail call <4 x half> @fast_inverse_f16x4(<4 x half> %vec) %result = extractelement <4 x half> %approx, i32 0 ret half %result @@ -106,7 +106,7 @@ define weak_odr <8 x half> @fast_inverse_f16x8(<8 x half> %x) nounwind alwaysinl } define weak_odr float @fast_inverse_sqrt_f32(float %x) nounwind alwaysinline { - %vec = insertelement <2 x float> undef, float %x, i32 0 + %vec = insertelement <2 x float> poison, float %x, i32 0 %approx = tail call <2 x float> @fast_inverse_sqrt_f32x2(<2 x float> %vec) %result = extractelement <2 x float> %approx, i32 0 ret float %result @@ -129,7 +129,7 @@ define weak_odr <4 x float> @fast_inverse_sqrt_f32x4(<4 x float> %x) nounwind al } define weak_odr half @fast_inverse_sqrt_f16(half %x) nounwind alwaysinline { - %vec = insertelement <4 x half> undef, half %x, i32 0 + %vec = insertelement <4 x half> poison, half %x, i32 0 %approx = tail call <4 x half> @fast_inverse_sqrt_f16x4(<4 x half> %vec) %result = extractelement <4 x half> %approx, i32 0 ret half %result diff --git a/src/runtime/arm.ll b/src/runtime/arm.ll index 295f5cbaf4a6..4a0624c143dc 100644 --- a/src/runtime/arm.ll +++ b/src/runtime/arm.ll @@ -56,7 +56,7 @@ declare <4 x float> @llvm.arm.neon.vrsqrts.v4f32(<4 x float> %x, <4 x float> %y) declare <2 x float> @llvm.arm.neon.vrsqrts.v2f32(<2 x float> %x, <2 x float> %y) nounwind readnone; define weak_odr float @fast_inverse_f32(float %x) nounwind alwaysinline { - %vec = insertelement <2 x float> undef, float %x, i32 0 + %vec = insertelement <2 x float> poison, float %x, i32 0 %approx = tail call <2 x float> @fast_inverse_f32x2(<2 x float> %vec) %result = extractelement <2 x float> %approx, i32 0 ret float %result @@ -77,7 +77,7 @@ define weak_odr <4 x float> @fast_inverse_f32x4(<4 x float> %x) nounwind alwaysi } define weak_odr float @fast_inverse_sqrt_f32(float %x) nounwind alwaysinline { - %vec = insertelement <2 x float> undef, float %x, i32 0 + %vec = insertelement <2 x float> poison, float %x, i32 0 %approx = tail call <2 x float> @fast_inverse_sqrt_f32x2(<2 x float> %vec) %result = extractelement <2 x float> %approx, i32 0 ret float %result diff --git a/src/runtime/force_include_types.cpp b/src/runtime/force_include_types.cpp new file mode 100644 index 000000000000..f5eeda611180 --- /dev/null +++ b/src/runtime/force_include_types.cpp @@ -0,0 +1,21 @@ +#include "HalideRuntime.h" +#include "runtime_internal.h" + +namespace Halide { +namespace Runtime { +namespace Internal { + +struct AllTheTypes { + halide_filter_metadata_t a; + halide_filter_argument_t b; + halide_scalar_value_t c; + halide_semaphore_t d; +}; + +WEAK void halide_unused_force_include_types() { + static __attribute__((used)) AllTheTypes a; +} + +} // namespace Internal +} // namespace Runtime +} // namespace Halide diff --git a/src/runtime/hashmap.h b/src/runtime/hashmap.h index 479becce6f16..3c224b870963 100644 --- a/src/runtime/hashmap.h +++ b/src/runtime/hashmap.h @@ -426,4 +426,4 @@ struct THashMap : public HashMap { } // namespace Runtime } // namespace Halide -#endif //HALIDE_RUNTIME_HASHMAP_H +#endif // HALIDE_RUNTIME_HASHMAP_H diff --git a/src/runtime/hexagon_dma.cpp b/src/runtime/hexagon_dma.cpp index 4db1e81d8f93..4512abb31c72 100644 --- a/src/runtime/hexagon_dma.cpp +++ b/src/runtime/hexagon_dma.cpp @@ -163,7 +163,7 @@ void desc_pool_free(void *user_context) { // User ptovided Image format to DMA format conversion. inline t_eDmaFmt halide_hexagon_get_dma_format(void *user_context, const halide_hexagon_image_fmt_t format) { - //A giant switch case to match image formats to dma formats + // A giant switch case to match image formats to dma formats switch (format) { case halide_hexagon_fmt_NV12: return eDmaFmt_NV12; @@ -486,7 +486,7 @@ WEAK int halide_hexagon_dma_unprepare(void *user_context, struct halide_buffer_t debug(user_context) << "Hexagon: halide_hexagon_dma_unprepare (user_context: " << user_context << ", buf: " << *buf << ")\n"; - //TODO Now that FinishFrame is called by Hexagon DMA Pool Module, need to check if this function is redundant + // TODO Now that FinishFrame is called by Hexagon DMA Pool Module, need to check if this function is redundant return halide_error_code_success; } diff --git a/src/runtime/internal/block_allocator.h b/src/runtime/internal/block_allocator.h new file mode 100644 index 000000000000..8dd7e4fc6dfa --- /dev/null +++ b/src/runtime/internal/block_allocator.h @@ -0,0 +1,482 @@ +#ifndef HALIDE_RUNTIME_BLOCK_ALLOCATOR_H +#define HALIDE_RUNTIME_BLOCK_ALLOCATOR_H + +#include "linked_list.h" +#include "memory_resources.h" +#include "region_allocator.h" + +namespace Halide { +namespace Runtime { +namespace Internal { + +// -- + +/** Allocator class interface for managing large contiguous blocks + * of memory, which are then sub-allocated into smaller regions of + * memory. This class only manages the address creation for the + * regions -- allocation callback functions are used to request the + * memory from the necessary system or API calls. This class is + * intended to be used inside of a higher level memory management + * class that provides thread safety, policy management and API + * integration for a specific runtime API (eg Vulkan, OpenCL, etc) + */ +class BlockAllocator { +public: + // disable copy constructors and assignment + BlockAllocator(const BlockAllocator &) = delete; + BlockAllocator &operator=(const BlockAllocator &) = delete; + + // disable non-factory based construction + BlockAllocator() = delete; + ~BlockAllocator() = delete; + + // Allocators for the different types of memory we need to allocate + struct MemoryAllocators { + SystemMemoryAllocatorFns system; + MemoryBlockAllocatorFns block; + MemoryRegionAllocatorFns region; + }; + + // Runtime configuration parameters to adjust the behaviour of the block allocator + struct Config { + size_t initial_capacity = 0; + size_t minimum_block_size = 0; + size_t maximum_block_size = 0; + size_t maximum_block_count = 0; + }; + + // Factory methods for creation / destruction + static BlockAllocator *create(void *user_context, const Config &config, const MemoryAllocators &allocators); + static void destroy(void *user_context, BlockAllocator *block_allocator); + + // Public interface methods + MemoryRegion *reserve(void *user_context, const MemoryRequest &request); + void reclaim(void *user_context, MemoryRegion *region); + bool collect(void *user_context); //< returns true if any blocks were removed + void release(void *user_context); + void destroy(void *user_context); + + // Access methods + const MemoryAllocators ¤t_allocators() const; + const Config ¤t_config() const; + const Config &default_config() const; + size_t block_count() const; + +private: + // Linked-list for storing the block resources + typedef LinkedList::EntryType BlockEntry; + + // Initializes a new instance + void initialize(void *user_context, const Config &config, const MemoryAllocators &allocators); + + // Reserves a region of memory using the given allocator for the given block resource, returns nullptr on failure + MemoryRegion *reserve_memory_region(void *user_context, RegionAllocator *allocator, const MemoryRequest &request); + + // Creates a new region allocator for the given block resource + RegionAllocator *create_region_allocator(void *user_context, BlockResource *block); + + // Destroys the given region allocator and all associated memory regions + void destroy_region_allocator(void *user_context, RegionAllocator *region_allocator); + + // Reserves a block of memory for the requested size and returns the corresponding block entry, or nullptr on failure + BlockEntry *reserve_block_entry(void *user_context, const MemoryProperties &properties, size_t size, bool dedicated); + + // Locates the "best-fit" block entry for the requested size, or nullptr if none was found + BlockEntry *find_block_entry(void *user_context, const MemoryProperties &properties, size_t size, bool dedicated); + + // Creates a new block entry and int the list + BlockEntry *create_block_entry(void *user_context, const MemoryProperties &properties, size_t size, bool dedicated); + + // Releases the block entry from being used, and makes it available for further allocations + void release_block_entry(void *user_context, BlockEntry *block_entry); + + // Destroys the block entry and removes it from the list + void destroy_block_entry(void *user_context, BlockEntry *block_entry); + + // Invokes the allocation callback to allocate memory for the block region + void alloc_memory_block(void *user_context, BlockResource *block); + + // Invokes the deallocation callback to free memory for the memory block + void free_memory_block(void *user_context, BlockResource *block); + + // Returns a constrained size for the requested size based on config parameters + size_t constrain_requested_size(size_t size) const; + + // Returns true if the given block is compatible with the given properties + bool is_compatible_block(const BlockResource *block, const MemoryProperties &properties) const; + + Config config; + LinkedList block_list; + MemoryAllocators allocators; +}; + +BlockAllocator *BlockAllocator::create(void *user_context, const Config &cfg, const MemoryAllocators &allocators) { + halide_abort_if_false(user_context, allocators.system.allocate != nullptr); + BlockAllocator *result = reinterpret_cast( + allocators.system.allocate(user_context, sizeof(BlockAllocator))); + + if (result == nullptr) { + error(user_context) << "BlockAllocator: Failed to create instance! Out of memory!\n"; + return nullptr; + } + + result->initialize(user_context, cfg, allocators); + return result; +} + +void BlockAllocator::destroy(void *user_context, BlockAllocator *instance) { + halide_abort_if_false(user_context, instance != nullptr); + const MemoryAllocators &allocators = instance->allocators; + instance->destroy(user_context); + halide_abort_if_false(user_context, allocators.system.deallocate != nullptr); + allocators.system.deallocate(user_context, instance); +} + +void BlockAllocator::initialize(void *user_context, const Config &cfg, const MemoryAllocators &ma) { + config = cfg; + allocators = ma; + block_list.initialize(user_context, + sizeof(BlockResource), + config.initial_capacity, + allocators.system); +} + +MemoryRegion *BlockAllocator::reserve(void *user_context, const MemoryRequest &request) { +#ifdef DEBUG_RUNTIME + debug(user_context) << "BlockAllocator: Reserve (" + << "user_context=" << (void *)(user_context) << " " + << "offset=" << (uint32_t)request.offset << " " + << "size=" << (uint32_t)request.size << " " + << "dedicated=" << (request.dedicated ? "true" : "false") << " " + << "usage=" << halide_memory_usage_name(request.properties.usage) << " " + << "caching=" << halide_memory_caching_name(request.properties.caching) << " " + << "visibility=" << halide_memory_visibility_name(request.properties.visibility) << ") ...\n"; +#endif + BlockEntry *block_entry = reserve_block_entry(user_context, request.properties, request.size, request.dedicated); + if (block_entry == nullptr) { + debug(user_context) << "BlockAllocator: Failed to allocate new empty block of requested size (" + << (int32_t)(request.size) << " bytes)!\n"; + return nullptr; + } + + BlockResource *block = static_cast(block_entry->value); + halide_abort_if_false(user_context, block != nullptr); + halide_abort_if_false(user_context, block->allocator != nullptr); + + MemoryRegion *result = reserve_memory_region(user_context, block->allocator, request); + if (result == nullptr) { + + // Unable to reserve region in an existing block ... create a new block and try again. + size_t actual_size = constrain_requested_size(request.size); + block_entry = create_block_entry(user_context, request.properties, actual_size, request.dedicated); + if (block_entry == nullptr) { + debug(user_context) << "BlockAllocator: Out of memory! Failed to allocate empty block of size (" + << (int32_t)(actual_size) << " bytes)!\n"; + return nullptr; + } + + block = static_cast(block_entry->value); + if (block->allocator == nullptr) { + block->allocator = create_region_allocator(user_context, block); + } + + result = reserve_memory_region(user_context, block->allocator, request); + } + return result; +} + +void BlockAllocator::reclaim(void *user_context, MemoryRegion *memory_region) { + halide_abort_if_false(user_context, memory_region != nullptr); + RegionAllocator *allocator = RegionAllocator::find_allocator(user_context, memory_region); + if (allocator == nullptr) { + return; + } + allocator->reclaim(user_context, memory_region); +} + +bool BlockAllocator::collect(void *user_context) { + bool result = false; + BlockEntry *block_entry = block_list.back(); + while (block_entry != nullptr) { + BlockEntry *prev_entry = block_entry->prev_ptr; + + const BlockResource *block = static_cast(block_entry->value); + if (block->allocator == nullptr) { + block_entry = prev_entry; + continue; + } + + block->allocator->collect(user_context); + if (block->reserved == 0) { + destroy_block_entry(user_context, block_entry); + result = true; + } + + block_entry = prev_entry; + } + return result; +} + +void BlockAllocator::release(void *user_context) { + BlockEntry *block_entry = block_list.back(); + while (block_entry != nullptr) { + BlockEntry *prev_entry = block_entry->prev_ptr; + release_block_entry(user_context, block_entry); + block_entry = prev_entry; + } +} + +void BlockAllocator::destroy(void *user_context) { + BlockEntry *block_entry = block_list.back(); + while (block_entry != nullptr) { + BlockEntry *prev_entry = block_entry->prev_ptr; + destroy_block_entry(user_context, block_entry); + block_entry = prev_entry; + } + block_list.destroy(user_context); +} + +MemoryRegion *BlockAllocator::reserve_memory_region(void *user_context, RegionAllocator *allocator, const MemoryRequest &request) { + MemoryRegion *result = allocator->reserve(user_context, request); + if (result == nullptr) { +#ifdef DEBUG_RUNTIME + debug(user_context) << "BlockAllocator: Failed to allocate region of size (" + << (int32_t)(request.size) << " bytes)!\n"; +#endif + // allocator has enough free space, but not enough contiguous space + // -- collect and try to reallocate + if (allocator->collect(user_context)) { + result = allocator->reserve(user_context, request); + } + } + return result; +} + +BlockAllocator::BlockEntry * +BlockAllocator::find_block_entry(void *user_context, const MemoryProperties &properties, size_t size, bool dedicated) { + BlockEntry *block_entry = nullptr; + for (block_entry = block_list.front(); block_entry != nullptr; block_entry = block_entry->next_ptr) { + + const BlockResource *block = static_cast(block_entry->value); + if (!is_compatible_block(block, properties)) { + continue; + } + + // skip blocks that can't be dedicated to a single allocation + if (dedicated && (block->reserved > 0)) { + continue; + } + + // skip dedicated blocks that are already allocated + if (block->memory.dedicated && (block->reserved > 0)) { + continue; + } + + size_t available = (block->memory.size - block->reserved); + if (available >= size) { +#ifdef DEBUG_RUNTIME + debug(user_context) << "BlockAllocator: find_block_entry (FOUND) (" + << "user_context=" << (void *)(user_context) << " " + << "block_entry=" << (void *)(block_entry) << " " + << "size=" << (uint32_t)size << " " + << "dedicated=" << (dedicated ? "true" : "false") << " " + << "usage=" << halide_memory_usage_name(properties.usage) << " " + << "caching=" << halide_memory_caching_name(properties.caching) << " " + << "visibility=" << halide_memory_visibility_name(properties.visibility) << ") ...\n"; +#endif + break; + } + } + + return block_entry; +} + +BlockAllocator::BlockEntry * +BlockAllocator::reserve_block_entry(void *user_context, const MemoryProperties &properties, size_t size, bool dedicated) { + BlockEntry *block_entry = find_block_entry(user_context, properties, size, dedicated); + if (block_entry == nullptr) { + size_t actual_size = constrain_requested_size(size); + block_entry = create_block_entry(user_context, properties, actual_size, dedicated); + } + + if (block_entry) { + BlockResource *block = static_cast(block_entry->value); + if (block->allocator == nullptr) { + block->allocator = create_region_allocator(user_context, block); + } + } + return block_entry; +} + +RegionAllocator * +BlockAllocator::create_region_allocator(void *user_context, BlockResource *block) { +#ifdef DEBUG_RUNTIME + debug(user_context) << "BlockAllocator: Creating region allocator (" + << "user_context=" << (void *)(user_context) << " " + << "block_resource=" << (void *)(block) << ")...\n"; +#endif + halide_abort_if_false(user_context, block != nullptr); + RegionAllocator *region_allocator = RegionAllocator::create( + user_context, block, {allocators.system, allocators.region}); + + if (region_allocator == nullptr) { + error(user_context) << "BlockAllocator: Failed to create new region allocator!\n"; + return nullptr; + } + + return region_allocator; +} + +void BlockAllocator::destroy_region_allocator(void *user_context, RegionAllocator *region_allocator) { +#ifdef DEBUG_RUNTIME + debug(user_context) << "BlockAllocator: Destroying region allocator (" + << "user_context=" << (void *)(user_context) << " " + << "region_allocator=" << (void *)(region_allocator) << ")...\n"; +#endif + if (region_allocator == nullptr) { + return; + } + RegionAllocator::destroy(user_context, region_allocator); +} + +BlockAllocator::BlockEntry * +BlockAllocator::create_block_entry(void *user_context, const MemoryProperties &properties, size_t size, bool dedicated) { + if (config.maximum_block_count && (block_count() >= config.maximum_block_count)) { + debug(user_context) << "BlockAllocator: No free blocks found! Maximum block count reached (" + << (int32_t)(config.maximum_block_count) << ")!\n"; + return nullptr; + } + + BlockEntry *block_entry = block_list.append(user_context); + if (block_entry == nullptr) { + debug(user_context) << "BlockAllocator: Failed to allocate new block entry!\n"; + return nullptr; + } + +#ifdef DEBUG_RUNTIME + debug(user_context) << "BlockAllocator: Creating block entry (" + << "block_entry=" << (void *)(block_entry) << " " + << "block=" << (void *)(block_entry->value) << " " + << "allocator=" << (void *)(allocators.block.allocate) << ")...\n"; +#endif + + BlockResource *block = static_cast(block_entry->value); + block->memory.size = size; + block->memory.properties = properties; + block->memory.dedicated = dedicated; + block->reserved = 0; + block->allocator = create_region_allocator(user_context, block); + alloc_memory_block(user_context, block); + return block_entry; +} + +void BlockAllocator::release_block_entry(void *user_context, BlockAllocator::BlockEntry *block_entry) { +#ifdef DEBUG_RUNTIME + debug(user_context) << "BlockAllocator: Releasing block entry (" + << "block_entry=" << (void *)(block_entry) << " " + << "block=" << (void *)(block_entry->value) << ")...\n"; +#endif + BlockResource *block = static_cast(block_entry->value); + if (block->allocator) { + block->allocator->release(user_context); + } +} + +void BlockAllocator::destroy_block_entry(void *user_context, BlockAllocator::BlockEntry *block_entry) { +#ifdef DEBUG_RUNTIME + debug(user_context) << "BlockAllocator: Destroying block entry (" + << "block_entry=" << (void *)(block_entry) << " " + << "block=" << (void *)(block_entry->value) << " " + << "deallocator=" << (void *)(allocators.block.deallocate) << ")...\n"; +#endif + BlockResource *block = static_cast(block_entry->value); + if (block->allocator) { + destroy_region_allocator(user_context, block->allocator); + block->allocator = nullptr; + } + free_memory_block(user_context, block); + block_list.remove(user_context, block_entry); +} + +void BlockAllocator::alloc_memory_block(void *user_context, BlockResource *block) { +#ifdef DEBUG_RUNTIME + debug(user_context) << "BlockAllocator: Allocating block (ptr=" << (void *)block << " allocator=" << (void *)allocators.block.allocate << ")...\n"; +#endif + halide_abort_if_false(user_context, allocators.block.allocate != nullptr); + MemoryBlock *memory_block = &(block->memory); + allocators.block.allocate(user_context, memory_block); + block->reserved = 0; +} + +void BlockAllocator::free_memory_block(void *user_context, BlockResource *block) { +#ifdef DEBUG_RUNTIME + debug(user_context) << "BlockAllocator: Deallocating block (ptr=" << (void *)block << " allocator=" << (void *)allocators.block.deallocate << ")...\n"; +#endif + halide_abort_if_false(user_context, allocators.block.deallocate != nullptr); + MemoryBlock *memory_block = &(block->memory); + allocators.block.deallocate(user_context, memory_block); + block->reserved = 0; + block->memory.size = 0; +} + +size_t BlockAllocator::constrain_requested_size(size_t size) const { + size_t actual_size = size; + if (config.minimum_block_size) { + actual_size = ((actual_size < config.minimum_block_size) ? + config.minimum_block_size : + actual_size); + } + if (config.maximum_block_size) { + actual_size = ((actual_size > config.maximum_block_size) ? + config.maximum_block_size : + actual_size); + } + return actual_size; +} + +bool BlockAllocator::is_compatible_block(const BlockResource *block, const MemoryProperties &properties) const { + if (properties.caching != MemoryCaching::DefaultCaching) { + if (properties.caching != block->memory.properties.caching) { + return false; + } + } + + if (properties.visibility != MemoryVisibility::DefaultVisibility) { + if (properties.visibility != block->memory.properties.visibility) { + return false; + } + } + + if (properties.usage != MemoryUsage::DefaultUsage) { + if (properties.usage != block->memory.properties.usage) { + return false; + } + } + + return true; +} + +const BlockAllocator::MemoryAllocators &BlockAllocator::current_allocators() const { + return allocators; +} + +const BlockAllocator::Config &BlockAllocator::current_config() const { + return config; +} + +const BlockAllocator::Config &BlockAllocator::default_config() const { + static Config result; + return result; +} + +size_t BlockAllocator::block_count() const { + return block_list.size(); +} + +// -- + +} // namespace Internal +} // namespace Runtime +} // namespace Halide + +#endif // HALIDE_RUNTIME_BLOCK_ALLOCATOR_H diff --git a/src/runtime/internal/block_storage.h b/src/runtime/internal/block_storage.h new file mode 100644 index 000000000000..a552c0a438d9 --- /dev/null +++ b/src/runtime/internal/block_storage.h @@ -0,0 +1,429 @@ +#ifndef HALIDE_RUNTIME_BLOCK_STORAGE_H +#define HALIDE_RUNTIME_BLOCK_STORAGE_H + +#include "memory_resources.h" + +namespace Halide { +namespace Runtime { +namespace Internal { + +// Dynamically resizable array for block storage (eg plain old data) +// -- No usage of constructors/destructors for value type +// -- Assumes all elements stored are uniformly the same fixed size +// -- Allocations are done in blocks of a fixed size +// -- Implementation uses memcpy/memmove for copying +// -- Customizable allocator ... default uses NativeSystemAllocator +class BlockStorage { +public: + static constexpr size_t default_capacity = 32; // smallish + + // Configurable parameters + struct Config { + uint32_t entry_size = 1; // bytes per entry + uint32_t block_size = 32; // bytes per each allocation block + uint32_t minimum_capacity = default_capacity; + }; + + BlockStorage(void *user_context, const Config &cfg, const SystemMemoryAllocatorFns &sma = default_allocator()); + BlockStorage(const BlockStorage &other); + ~BlockStorage(); + + void initialize(void *user_context, const Config &cfg, const SystemMemoryAllocatorFns &sma = default_allocator()); + + BlockStorage &operator=(const BlockStorage &other); + bool operator==(const BlockStorage &other) const; + bool operator!=(const BlockStorage &other) const; + + void reserve(void *user_context, size_t capacity, bool free_existing = false); + void resize(void *user_context, size_t entry_count, bool realloc = true); + + void assign(void *user_context, size_t index, const void *entry_ptr); + void insert(void *user_context, size_t index, const void *entry_ptr); + void prepend(void *user_context, const void *entry_ptr); + void append(void *user_context, const void *entry_ptr); + void remove(void *user_context, size_t index); + + void fill(void *user_context, const void *array, size_t array_size); + void insert(void *user_context, size_t index, const void *array, size_t array_size); + void replace(void *user_context, size_t index, const void *array, size_t array_size); + void prepend(void *user_context, const void *array, size_t array_size); + void append(void *user_context, const void *array, size_t array_size); + void remove(void *user_context, size_t index, size_t entry_count); + + void pop_front(void *user_context); + void pop_back(void *user_context); + void shrink_to_fit(void *user_context); + void clear(void *user_context); + void destroy(void *user_context); + + bool empty() const; + size_t stride() const; + size_t size() const; + + void *operator[](size_t index); ///< logical entry index (returns ptr = data() + (index * stride()) + const void *operator[](size_t index) const; + + void *data(); + void *front(); + void *back(); + + const void *data() const; + const void *front() const; + const void *back() const; + + const Config ¤t_config() const; + static const Config &default_config(); + + const SystemMemoryAllocatorFns ¤t_allocator() const; + static const SystemMemoryAllocatorFns &default_allocator(); + +private: + void allocate(void *user_context, size_t capacity); + + void *ptr = nullptr; + size_t count = 0; + size_t capacity = 0; + Config config; + SystemMemoryAllocatorFns allocator; +}; + +BlockStorage::BlockStorage(void *user_context, const Config &cfg, const SystemMemoryAllocatorFns &sma) + : config(cfg), allocator(sma) { + halide_abort_if_false(user_context, config.entry_size != 0); + halide_abort_if_false(user_context, allocator.allocate != nullptr); + halide_abort_if_false(user_context, allocator.deallocate != nullptr); + if (config.minimum_capacity) { + reserve(user_context, config.minimum_capacity); + } +} + +BlockStorage::BlockStorage(const BlockStorage &other) + : BlockStorage(nullptr, other.config, other.allocator) { + if (other.count) { + resize(nullptr, other.count); + memcpy(this->ptr, other.ptr, count * config.entry_size); + } +} + +BlockStorage::~BlockStorage() { + destroy(nullptr); +} + +void BlockStorage::destroy(void *user_context) { + halide_abort_if_false(user_context, allocator.deallocate != nullptr); + if (ptr != nullptr) { + allocator.deallocate(user_context, ptr); + } + capacity = count = 0; + ptr = nullptr; +} + +void BlockStorage::initialize(void *user_context, const Config &cfg, const SystemMemoryAllocatorFns &sma) { + allocator = sma; + config = cfg; + capacity = count = 0; + ptr = nullptr; + if (config.minimum_capacity) { + reserve(user_context, config.minimum_capacity); + } +} + +BlockStorage &BlockStorage::operator=(const BlockStorage &other) { + if (&other != this) { + config = other.config; + resize(nullptr, other.count); + if (count != 0 && other.ptr != nullptr) { + memcpy(ptr, other.ptr, count * config.entry_size); + } + } + return *this; +} + +bool BlockStorage::operator==(const BlockStorage &other) const { + if (config.entry_size != other.config.entry_size) { + return false; + } + if (count != other.count) { + return false; + } + return memcmp(this->ptr, other.ptr, this->size() * config.entry_size) == 0; +} + +bool BlockStorage::operator!=(const BlockStorage &other) const { + return !(*this == other); +} + +void BlockStorage::fill(void *user_context, const void *array, size_t array_size) { + if (array_size != 0) { + resize(user_context, array_size); + memcpy(this->ptr, array, array_size * config.entry_size); + count = array_size; + } +} + +void BlockStorage::assign(void *user_context, size_t index, const void *entry_ptr) { + replace(user_context, index, entry_ptr, 1); +} + +void BlockStorage::prepend(void *user_context, const void *entry_ptr) { + insert(user_context, 0, entry_ptr, 1); +} + +void BlockStorage::append(void *user_context, const void *entry_ptr) { + append(user_context, entry_ptr, 1); +} + +void BlockStorage::pop_front(void *user_context) { + halide_debug_assert(user_context, count > 0); + remove(user_context, 0); +} + +void BlockStorage::pop_back(void *user_context) { + halide_debug_assert(user_context, count > 0); + resize(user_context, size() - 1); +} + +void BlockStorage::clear(void *user_context) { + resize(user_context, 0); +} + +void BlockStorage::reserve(void *user_context, size_t new_capacity, bool free_existing) { + new_capacity = max(new_capacity, count); + + if ((new_capacity < capacity) && !free_existing) { + new_capacity = capacity; + } + + allocate(user_context, new_capacity); +} + +void BlockStorage::resize(void *user_context, size_t entry_count, bool realloc) { + size_t current_size = capacity; + size_t requested_size = entry_count; + size_t minimum_size = config.minimum_capacity; + size_t actual_size = current_size; + count = requested_size; + + // increase capacity upto 1.5x existing (or at least min_capacity) + if (requested_size > current_size) { + actual_size = max(requested_size, max(current_size * 3 / 2, minimum_size)); + } else if (!realloc) { + return; + } + +#if DEBUG + debug(user_context) << "BlockStorage: Resize (" + << "requested_size=" << (int32_t)requested_size << " " + << "current_size=" << (int32_t)current_size << " " + << "minimum_size=" << (int32_t)minimum_size << " " + << "actual_size=" << (int32_t)actual_size << " " + << "entry_size=" << (int32_t)config.entry_size << " " + << "realloc=" << (realloc ? "true" : "false") << ")...\n"; +#endif + + allocate(user_context, actual_size); +} + +void BlockStorage::shrink_to_fit(void *user_context) { + if (capacity > count) { + void *new_ptr = nullptr; + if (count > 0) { + size_t actual_bytes = count * config.entry_size; + new_ptr = allocator.allocate(user_context, actual_bytes); + memcpy(new_ptr, ptr, actual_bytes); + } + allocator.deallocate(user_context, ptr); + capacity = count; + ptr = new_ptr; + } +} + +void BlockStorage::insert(void *user_context, size_t index, const void *entry_ptr) { + insert(user_context, index, entry_ptr, 1); +} + +void BlockStorage::remove(void *user_context, size_t index) { + remove(user_context, index, 1); +} + +void BlockStorage::remove(void *user_context, size_t index, size_t entry_count) { + halide_debug_assert(user_context, index < count); + const size_t last_index = size(); + if (index < (last_index - entry_count)) { + size_t dst_offset = index * config.entry_size; + size_t src_offset = (index + entry_count) * config.entry_size; + size_t bytes = (last_index - index - entry_count) * config.entry_size; + +#if DEBUG + debug(0) << "BlockStorage: Remove (" + << "index=" << (int32_t)index << " " + << "entry_count=" << (int32_t)entry_count << " " + << "entry_size=" << (int32_t)config.entry_size << " " + << "last_index=" << (int32_t)last_index << " " + << "src_offset=" << (int32_t)src_offset << " " + << "dst_offset=" << (int32_t)dst_offset << " " + << "bytes=" << (int32_t)bytes << ")...\n"; +#endif + void *dst_ptr = offset_address(ptr, dst_offset); + void *src_ptr = offset_address(ptr, src_offset); + memmove(dst_ptr, src_ptr, bytes); + } + resize(user_context, last_index - entry_count); +} + +void BlockStorage::replace(void *user_context, size_t index, const void *array, size_t array_size) { + halide_debug_assert(user_context, index < count); + size_t offset = index * config.entry_size; + size_t remaining = count - index; + +#if DEBUG + debug(0) << "BlockStorage: Replace (" + << "index=" << (int32_t)index << " " + << "array_size=" << (int32_t)array_size << " " + << "entry_size=" << (int32_t)config.entry_size << " " + << "offset=" << (int32_t)offset << " " + << "remaining=" << (int32_t)remaining << " " + << "capacity=" << (int32_t)capacity << ")...\n"; +#endif + + halide_debug_assert(user_context, remaining > 0); + size_t copy_count = min(remaining, array_size); + void *dst_ptr = offset_address(ptr, offset); + memcpy(dst_ptr, array, copy_count * config.entry_size); + count = max(count, index + copy_count); +} + +void BlockStorage::insert(void *user_context, size_t index, const void *array, size_t array_size) { + halide_debug_assert(user_context, index <= count); + const size_t last_index = size(); + resize(user_context, last_index + array_size); + if (index < last_index) { + size_t src_offset = index * config.entry_size; + size_t dst_offset = (index + array_size) * config.entry_size; + size_t bytes = (last_index - index) * config.entry_size; + void *src_ptr = offset_address(ptr, src_offset); + void *dst_ptr = offset_address(ptr, dst_offset); + memmove(dst_ptr, src_ptr, bytes); + } + replace(user_context, index, array, array_size); +} + +void BlockStorage::prepend(void *user_context, const void *array, size_t array_size) { + insert(user_context, 0, array, array_size); +} + +void BlockStorage::append(void *user_context, const void *array, size_t array_size) { + const size_t last_index = size(); + insert(user_context, last_index, array, array_size); +} + +bool BlockStorage::empty() const { + return count == 0; +} + +size_t BlockStorage::size() const { + return count; +} + +size_t BlockStorage::stride() const { + return config.entry_size; +} + +void *BlockStorage::operator[](size_t index) { + halide_debug_assert(nullptr, index < capacity); + return offset_address(ptr, index * config.entry_size); +} + +const void *BlockStorage::operator[](size_t index) const { + halide_debug_assert(nullptr, index < capacity); + return offset_address(ptr, index * config.entry_size); +} + +void *BlockStorage::data() { + return ptr; +} + +void *BlockStorage::front() { + halide_debug_assert(nullptr, count > 0); + return ptr; +} + +void *BlockStorage::back() { + halide_debug_assert(nullptr, count > 0); + size_t index = count - 1; + return offset_address(ptr, index * config.entry_size); +} + +const void *BlockStorage::data() const { + return ptr; +} + +const void *BlockStorage::front() const { + halide_debug_assert(nullptr, count > 0); + return ptr; +} + +const void *BlockStorage::back() const { + halide_debug_assert(nullptr, count > 0); + size_t index = count - 1; + return offset_address(ptr, index * config.entry_size); +} + +void BlockStorage::allocate(void *user_context, size_t new_capacity) { + if (new_capacity != capacity) { + halide_abort_if_false(user_context, allocator.allocate != nullptr); + size_t requested_bytes = new_capacity * config.entry_size; + size_t block_size = max(config.block_size, config.entry_size); + size_t block_count = (requested_bytes / block_size); + block_count += (requested_bytes % block_size) ? 1 : 0; + size_t alloc_size = block_count * block_size; +#if DEBUG + debug(0) << "BlockStorage: Allocating (" + << "requested_bytes=" << (int32_t)requested_bytes << " " + << "block_size=" << (int32_t)block_size << " " + << "block_count=" << (int32_t)block_count << " " + << "alloc_size=" << (int32_t)alloc_size << ") ...\n"; +#endif + void *new_ptr = alloc_size ? allocator.allocate(user_context, alloc_size) : nullptr; + if (count != 0 && ptr != nullptr && new_ptr != nullptr) { + memcpy(new_ptr, ptr, count * config.entry_size); + } + if (ptr != nullptr) { + halide_abort_if_false(user_context, allocator.deallocate != nullptr); + allocator.deallocate(user_context, ptr); + } + capacity = new_capacity; + ptr = new_ptr; + } +} + +const SystemMemoryAllocatorFns & +BlockStorage::current_allocator() const { + return this->allocator; +} + +const BlockStorage::Config & +BlockStorage::default_config() { + static Config default_cfg; + return default_cfg; +} + +const BlockStorage::Config & +BlockStorage::current_config() const { + return this->config; +} + +const SystemMemoryAllocatorFns & +BlockStorage::default_allocator() { + static SystemMemoryAllocatorFns native_allocator = { + native_system_malloc, native_system_free}; + return native_allocator; +} + +// -- + +} // namespace Internal +} // namespace Runtime +} // namespace Halide + +#endif // HALIDE_RUNTIME_BLOCK_STORAGE_H diff --git a/src/runtime/internal/linked_list.h b/src/runtime/internal/linked_list.h new file mode 100644 index 000000000000..1ddccbb5ff25 --- /dev/null +++ b/src/runtime/internal/linked_list.h @@ -0,0 +1,337 @@ +#ifndef HALIDE_RUNTIME_LINKED_LIST_H +#define HALIDE_RUNTIME_LINKED_LIST_H + +#include "memory_arena.h" + +namespace Halide { +namespace Runtime { +namespace Internal { + +// Doubly linked list container +// -- Implemented using MemoryArena for allocation +class LinkedList { +public: + // Disable copy support + LinkedList(const LinkedList &) = delete; + LinkedList &operator=(const LinkedList &) = delete; + + // Default initial capacity + static constexpr uint32_t default_capacity = uint32_t(32); // smallish + + // List entry + struct EntryType { + void *value = nullptr; + EntryType *prev_ptr = nullptr; + EntryType *next_ptr = nullptr; + }; + + LinkedList(void *user_context, uint32_t entry_size, uint32_t capacity = default_capacity, + const SystemMemoryAllocatorFns &allocator = default_allocator()); + ~LinkedList(); + + void initialize(void *user_context, uint32_t entry_size, uint32_t capacity = default_capacity, + const SystemMemoryAllocatorFns &allocator = default_allocator()); + + EntryType *front(); + EntryType *back(); + + const EntryType *front() const; + const EntryType *back() const; + + EntryType *prepend(void *user_context); + EntryType *prepend(void *user_context, const void *value); + + EntryType *append(void *user_context); + EntryType *append(void *user_context, const void *value); + + void pop_front(void *user_context); + void pop_back(void *user_context); + + EntryType *insert_before(void *user_context, EntryType *entry_ptr); + EntryType *insert_before(void *user_context, EntryType *entry_ptr, const void *value); + + EntryType *insert_after(void *user_context, EntryType *entry_ptr); + EntryType *insert_after(void *user_context, EntryType *entry_ptr, const void *value); + + void remove(void *user_context, EntryType *entry_ptr); + void clear(void *user_context); + void destroy(void *user_context); + + size_t size() const; + bool empty() const; + + const SystemMemoryAllocatorFns ¤t_allocator() const; + static const SystemMemoryAllocatorFns &default_allocator(); + +private: + EntryType *reserve(void *user_context); + void reclaim(void *user_context, EntryType *entry_ptr); + + MemoryArena *link_arena = nullptr; + MemoryArena *data_arena = nullptr; + EntryType *front_ptr = nullptr; + EntryType *back_ptr = nullptr; + size_t entry_count = 0; +}; + +LinkedList::LinkedList(void *user_context, uint32_t entry_size, uint32_t capacity, + const SystemMemoryAllocatorFns &sma) { + uint32_t arena_capacity = max(capacity, MemoryArena::default_capacity); + link_arena = MemoryArena::create(user_context, {sizeof(EntryType), arena_capacity, 0}, sma); + data_arena = MemoryArena::create(user_context, {entry_size, arena_capacity, 0}, sma); + front_ptr = nullptr; + back_ptr = nullptr; + entry_count = 0; +} + +LinkedList::~LinkedList() { + destroy(nullptr); +} + +void LinkedList::initialize(void *user_context, uint32_t entry_size, uint32_t capacity, + const SystemMemoryAllocatorFns &sma) { + uint32_t arena_capacity = max(capacity, MemoryArena::default_capacity); + link_arena = MemoryArena::create(user_context, {sizeof(EntryType), arena_capacity, 0}, sma); + data_arena = MemoryArena::create(user_context, {entry_size, arena_capacity, 0}, sma); + front_ptr = nullptr; + back_ptr = nullptr; + entry_count = 0; +} + +void LinkedList::destroy(void *user_context) { + clear(nullptr); + if (link_arena) { + MemoryArena::destroy(nullptr, link_arena); + } + if (data_arena) { + MemoryArena::destroy(nullptr, data_arena); + } + link_arena = nullptr; + data_arena = nullptr; + front_ptr = nullptr; + back_ptr = nullptr; + entry_count = 0; +} + +typename LinkedList::EntryType *LinkedList::front() { + return front_ptr; +} + +typename LinkedList::EntryType *LinkedList::back() { + return back_ptr; +} + +const typename LinkedList::EntryType *LinkedList::front() const { + return front_ptr; +} + +const typename LinkedList::EntryType *LinkedList::back() const { + return back_ptr; +} + +typename LinkedList::EntryType * +LinkedList::prepend(void *user_context) { + EntryType *entry_ptr = reserve(user_context); + if (empty()) { + front_ptr = entry_ptr; + back_ptr = entry_ptr; + entry_count = 1; + } else { + entry_ptr->next_ptr = front_ptr; + front_ptr->prev_ptr = entry_ptr; + front_ptr = entry_ptr; + ++entry_count; + } + return entry_ptr; +} + +typename LinkedList::EntryType * +LinkedList::append(void *user_context) { + EntryType *entry_ptr = reserve(user_context); + if (empty()) { + front_ptr = entry_ptr; + back_ptr = entry_ptr; + entry_count = 1; + } else { + entry_ptr->prev_ptr = back_ptr; + back_ptr->next_ptr = entry_ptr; + back_ptr = entry_ptr; + ++entry_count; + } + return entry_ptr; +} + +typename LinkedList::EntryType * +LinkedList::prepend(void *user_context, const void *value) { + EntryType *entry_ptr = prepend(user_context); + memcpy(entry_ptr->value, value, data_arena->current_config().entry_size); + return entry_ptr; +} + +typename LinkedList::EntryType * +LinkedList::append(void *user_context, const void *value) { + EntryType *entry_ptr = append(user_context); + memcpy(entry_ptr->value, value, data_arena->current_config().entry_size); + return entry_ptr; +} + +void LinkedList::pop_front(void *user_context) { + halide_abort_if_false(user_context, (entry_count > 0)); + EntryType *remove_ptr = front_ptr; + EntryType *next_ptr = remove_ptr->next_ptr; + if (next_ptr != nullptr) { + next_ptr->prev_ptr = nullptr; + } + front_ptr = next_ptr; + reclaim(user_context, remove_ptr); + --entry_count; +} + +void LinkedList::pop_back(void *user_context) { + halide_abort_if_false(user_context, (entry_count > 0)); + EntryType *remove_ptr = back_ptr; + EntryType *prev_ptr = remove_ptr->prev_ptr; + if (prev_ptr != nullptr) { + prev_ptr->next_ptr = nullptr; + } + back_ptr = prev_ptr; + reclaim(user_context, remove_ptr); + --entry_count; +} + +void LinkedList::clear(void *user_context) { + if (empty() == false) { + EntryType *remove_ptr = back_ptr; + while (remove_ptr != nullptr) { + EntryType *prev_ptr = remove_ptr->prev_ptr; + reclaim(user_context, remove_ptr); + remove_ptr = prev_ptr; + } + front_ptr = nullptr; + back_ptr = nullptr; + entry_count = 0; + } +} + +void LinkedList::remove(void *user_context, EntryType *entry_ptr) { + halide_abort_if_false(user_context, (entry_ptr != nullptr)); + halide_abort_if_false(user_context, (entry_count > 0)); + + if (entry_ptr->prev_ptr != nullptr) { + entry_ptr->prev_ptr->next_ptr = entry_ptr->next_ptr; + } else { + halide_abort_if_false(user_context, (front_ptr == entry_ptr)); + front_ptr = entry_ptr->next_ptr; + } + + if (entry_ptr->next_ptr != nullptr) { + entry_ptr->next_ptr->prev_ptr = entry_ptr->prev_ptr; + } else { + halide_abort_if_false(user_context, (back_ptr == entry_ptr)); + back_ptr = entry_ptr->prev_ptr; + } + + reclaim(user_context, entry_ptr); + --entry_count; +} + +typename LinkedList::EntryType * +LinkedList::insert_before(void *user_context, EntryType *entry_ptr) { + if (entry_ptr != nullptr) { + EntryType *prev_ptr = entry_ptr->prev_ptr; + EntryType *new_ptr = reserve(user_context); + new_ptr->prev_ptr = prev_ptr; + new_ptr->next_ptr = entry_ptr; + entry_ptr->prev_ptr = new_ptr; + if (prev_ptr != nullptr) { + prev_ptr->next_ptr = new_ptr; + } else { + halide_abort_if_false(user_context, (front_ptr == entry_ptr)); + front_ptr = new_ptr; + } + ++entry_count; + return new_ptr; + } else { + return append(user_context); + } +} + +typename LinkedList::EntryType * +LinkedList::insert_after(void *user_context, EntryType *entry_ptr) { + if (entry_ptr != nullptr) { + EntryType *next_ptr = entry_ptr->next_ptr; + EntryType *new_ptr = reserve(user_context); + new_ptr->next_ptr = next_ptr; + new_ptr->prev_ptr = entry_ptr; + entry_ptr->next_ptr = new_ptr; + if (next_ptr != nullptr) { + next_ptr->prev_ptr = new_ptr; + } else { + halide_abort_if_false(user_context, (back_ptr == entry_ptr)); + back_ptr = new_ptr; + } + ++entry_count; + return new_ptr; + } else { + return prepend(user_context); + } +} + +typename LinkedList::EntryType * +LinkedList::insert_before(void *user_context, EntryType *entry_ptr, const void *value) { + EntryType *new_ptr = insert_before(user_context, entry_ptr); + memcpy(new_ptr->value, value, data_arena->current_config().entry_size); + return new_ptr; +} + +typename LinkedList::EntryType * +LinkedList::insert_after(void *user_context, EntryType *entry_ptr, const void *value) { + EntryType *new_ptr = insert_after(user_context, entry_ptr); + memcpy(new_ptr->value, value, data_arena->current_config().entry_size); + return new_ptr; +} + +size_t LinkedList::size() const { + return entry_count; +} + +bool LinkedList::empty() const { + return entry_count == 0; +} + +const SystemMemoryAllocatorFns & +LinkedList::current_allocator() const { + return link_arena->current_allocator(); +} + +const SystemMemoryAllocatorFns & +LinkedList::default_allocator() { + return MemoryArena::default_allocator(); +} + +typename LinkedList::EntryType * +LinkedList::reserve(void *user_context) { + EntryType *entry_ptr = static_cast( + link_arena->reserve(user_context, true)); + entry_ptr->value = data_arena->reserve(user_context, true); + entry_ptr->next_ptr = nullptr; + entry_ptr->prev_ptr = nullptr; + return entry_ptr; +} + +void LinkedList::reclaim(void *user_context, EntryType *entry_ptr) { + void *value_ptr = entry_ptr->value; + entry_ptr->value = nullptr; + entry_ptr->next_ptr = nullptr; + entry_ptr->prev_ptr = nullptr; + data_arena->reclaim(user_context, value_ptr); + link_arena->reclaim(user_context, entry_ptr); +} + +// -- + +} // namespace Internal +} // namespace Runtime +} // namespace Halide + +#endif // HALIDE_RUNTIME_LINKED_LIST_H diff --git a/src/runtime/internal/memory_arena.h b/src/runtime/internal/memory_arena.h new file mode 100644 index 000000000000..3f19ce45f71a --- /dev/null +++ b/src/runtime/internal/memory_arena.h @@ -0,0 +1,312 @@ +#ifndef HALIDE_RUNTIME_MEMORY_ARENA_H +#define HALIDE_RUNTIME_MEMORY_ARENA_H + +#include "block_storage.h" + +namespace Halide { +namespace Runtime { +namespace Internal { + +// -- +// Memory Arena class for region based allocations and caching of same-type data +// -- Implementation uses block_storage, and internally manages lists of allocated entries +// -- Customizable allocator (defaults to BlockStorage::default_allocator()) +// -- Not thread safe ... locking must be done by client +// +class MemoryArena { +public: + // Disable copy constructors and assignment + MemoryArena(const MemoryArena &) = delete; + MemoryArena &operator=(const MemoryArena &) = delete; + + // Default initial capacity + static constexpr uint32_t default_capacity = uint32_t(32); // smallish + + // Configurable parameters + struct Config { + uint32_t entry_size = 1; + uint32_t minimum_block_capacity = default_capacity; + uint32_t maximum_block_count = 0; + }; + + MemoryArena(void *user_context, const Config &config = default_config(), + const SystemMemoryAllocatorFns &allocator = default_allocator()); + + ~MemoryArena(); + + // Factory methods for creation / destruction + static MemoryArena *create(void *user_context, const Config &config, const SystemMemoryAllocatorFns &allocator = default_allocator()); + static void destroy(void *user_context, MemoryArena *arena); + + // Initialize a newly created instance + void initialize(void *user_context, const Config &config, + const SystemMemoryAllocatorFns &allocator = default_allocator()); + + // Public interface methods + void *reserve(void *user_context, bool initialize = false); + void reclaim(void *user_context, void *ptr); + bool collect(void *user_context); //< returns true if any blocks were removed + void destroy(void *user_context); + + // Access methods + const Config ¤t_config() const; + static const Config &default_config(); + + const SystemMemoryAllocatorFns ¤t_allocator() const; + static const SystemMemoryAllocatorFns &default_allocator(); + +private: + // Sentinal invalid entry value + static const uint32_t invalid_entry = uint32_t(-1); + + // Each block contains: + // - an array of entries + // - an array of indices (for the free list) + // - an array of status flags (indicating usage) + // - free index points to next available entry for the block (or invalid_entry if block is full) + struct Block { + void *entries = nullptr; + uint32_t *indices = nullptr; + AllocationStatus *status = nullptr; + uint32_t capacity = 0; + uint32_t free_index = 0; + }; + + Block *create_block(void *user_context); + bool collect_block(void *user_context, Block *block); //< returns true if any blocks were removed + void destroy_block(void *user_context, Block *block); + Block *lookup_block(void *user_context, uint32_t index); + + void *create_entry(void *user_context, Block *block, uint32_t index); + void destroy_entry(void *user_context, Block *block, uint32_t index); + void *lookup_entry(void *user_context, Block *block, uint32_t index); + + Config config; + BlockStorage blocks; +}; + +MemoryArena::MemoryArena(void *user_context, + const Config &cfg, + const SystemMemoryAllocatorFns &alloc) + : config(cfg), + blocks(user_context, {sizeof(MemoryArena::Block), 32, 32}, alloc) { + halide_debug_assert(user_context, config.minimum_block_capacity > 1); +} + +MemoryArena::~MemoryArena() { + destroy(nullptr); +} + +MemoryArena *MemoryArena::create(void *user_context, const Config &cfg, const SystemMemoryAllocatorFns &system_allocator) { + halide_abort_if_false(user_context, system_allocator.allocate != nullptr); + MemoryArena *result = reinterpret_cast( + system_allocator.allocate(user_context, sizeof(MemoryArena))); + + if (result == nullptr) { + halide_error(user_context, "MemoryArena: Failed to create instance! Out of memory!\n"); + return nullptr; + } + + result->initialize(user_context, cfg, system_allocator); + return result; +} + +void MemoryArena::destroy(void *user_context, MemoryArena *instance) { + halide_abort_if_false(user_context, instance != nullptr); + const SystemMemoryAllocatorFns &system_allocator = instance->blocks.current_allocator(); + instance->destroy(user_context); + halide_abort_if_false(user_context, system_allocator.deallocate != nullptr); + system_allocator.deallocate(user_context, instance); +} + +void MemoryArena::initialize(void *user_context, + const Config &cfg, + const SystemMemoryAllocatorFns &system_allocator) { + config = cfg; + blocks.initialize(user_context, {sizeof(MemoryArena::Block), 32, 32}, system_allocator); + halide_debug_assert(user_context, config.minimum_block_capacity > 1); +} + +void MemoryArena::destroy(void *user_context) { + if (!blocks.empty()) { + for (size_t i = blocks.size(); i--;) { + Block *block = lookup_block(user_context, i); + halide_abort_if_false(user_context, block != nullptr); + destroy_block(user_context, block); + } + } + blocks.destroy(user_context); +} + +bool MemoryArena::collect(void *user_context) { + bool result = false; + for (size_t i = blocks.size(); i--;) { + Block *block = lookup_block(user_context, i); + halide_abort_if_false(user_context, block != nullptr); + if (collect_block(user_context, block)) { + blocks.remove(user_context, i); + result = true; + } + } + return result; +} + +void *MemoryArena::reserve(void *user_context, bool initialize) { + // Scan blocks for a free entry + for (size_t i = blocks.size(); i--;) { + Block *block = lookup_block(user_context, i); + halide_abort_if_false(user_context, block != nullptr); + if (block->free_index != invalid_entry) { + return create_entry(user_context, block, block->free_index); + } + } + + if (config.maximum_block_count && (blocks.size() >= config.maximum_block_count)) { + halide_error(user_context, "MemoryArena: Failed to reserve new entry! Maxmimum blocks reached!\n"); + return nullptr; + } + + // All blocks full ... create a new one + uint32_t index = 0; + Block *block = create_block(user_context); + void *entry_ptr = create_entry(user_context, block, index); + + // Optionally clear the allocation if requested + if (initialize) { + memset(entry_ptr, 0, config.entry_size); + } + return entry_ptr; +} + +void MemoryArena::reclaim(void *user_context, void *entry_ptr) { + for (size_t i = blocks.size(); i--;) { + Block *block = lookup_block(user_context, i); + halide_abort_if_false(user_context, block != nullptr); + + // is entry_ptr in the address range of this block. + uint8_t *offset_ptr = static_cast(entry_ptr); + uint8_t *base_ptr = static_cast(block->entries); + uint8_t *end_ptr = static_cast(offset_address(block->entries, block->capacity * config.entry_size)); + if ((entry_ptr >= base_ptr) && (entry_ptr < end_ptr)) { + const uint32_t offset = static_cast(offset_ptr - base_ptr); + const uint32_t index = offset / config.entry_size; + destroy_entry(user_context, block, index); + return; + } + } + halide_error(user_context, "MemoryArena: Pointer address doesn't belong to this memory pool!\n"); +} + +typename MemoryArena::Block *MemoryArena::create_block(void *user_context) { + // resize capacity starting with initial up to 1.5 last capacity + uint32_t new_capacity = config.minimum_block_capacity; + if (!blocks.empty()) { + const Block *last_block = static_cast(blocks.back()); + new_capacity = (last_block->capacity * 3 / 2); + } + + halide_abort_if_false(user_context, current_allocator().allocate != nullptr); + void *new_entries = current_allocator().allocate(user_context, config.entry_size * new_capacity); + memset(new_entries, 0, config.entry_size * new_capacity); + + uint32_t *new_indices = (uint32_t *)current_allocator().allocate(user_context, sizeof(uint32_t) * new_capacity); + AllocationStatus *new_status = (AllocationStatus *)current_allocator().allocate(user_context, sizeof(AllocationStatus) * new_capacity); + + for (uint32_t i = 0; i < new_capacity - 1; ++i) { + new_indices[i] = i + 1; // singly-linked list of all free entries in the block + new_status[i] = AllocationStatus::Available; // usage status + } + + new_indices[new_capacity - 1] = invalid_entry; + new_status[new_capacity - 1] = AllocationStatus::InvalidStatus; + + const Block new_block = {new_entries, new_indices, new_status, new_capacity, 0}; + blocks.append(user_context, &new_block); + return static_cast(blocks.back()); +} + +void MemoryArena::destroy_block(void *user_context, Block *block) { + halide_abort_if_false(user_context, block != nullptr); + if (block->entries != nullptr) { + halide_abort_if_false(user_context, current_allocator().deallocate != nullptr); + current_allocator().deallocate(user_context, block->entries); + current_allocator().deallocate(user_context, block->indices); + current_allocator().deallocate(user_context, block->status); + block->entries = nullptr; + block->indices = nullptr; + block->status = nullptr; + } +} + +bool MemoryArena::collect_block(void *user_context, Block *block) { + halide_abort_if_false(user_context, block != nullptr); + if (block->entries != nullptr) { + bool can_collect = true; + for (size_t i = block->capacity; i--;) { + if (block->status[i] == AllocationStatus::InUse) { + can_collect = false; + break; + } + } + if (can_collect) { + destroy_block(user_context, block); + return true; + } + } + return false; +} + +MemoryArena::Block *MemoryArena::lookup_block(void *user_context, uint32_t index) { + return static_cast(blocks[index]); +} + +void *MemoryArena::lookup_entry(void *user_context, Block *block, uint32_t index) { + halide_abort_if_false(user_context, block != nullptr); + halide_abort_if_false(user_context, block->entries != nullptr); + return offset_address(block->entries, index * config.entry_size); +} + +void *MemoryArena::create_entry(void *user_context, Block *block, uint32_t index) { + void *entry_ptr = lookup_entry(user_context, block, index); + block->free_index = block->indices[index]; + block->status[index] = AllocationStatus::InUse; +#if DEBUG_RUNTIME + memset(entry_ptr, 0, config.entry_size); +#endif + return entry_ptr; +} + +void MemoryArena::destroy_entry(void *user_context, Block *block, uint32_t index) { + block->status[index] = AllocationStatus::Available; + block->indices[index] = block->free_index; + block->free_index = index; +} + +const typename MemoryArena::Config & +MemoryArena::current_config() const { + return config; +} + +const typename MemoryArena::Config & +MemoryArena::default_config() { + static Config result; + return result; +} + +const SystemMemoryAllocatorFns & +MemoryArena::current_allocator() const { + return blocks.current_allocator(); +} + +const SystemMemoryAllocatorFns & +MemoryArena::default_allocator() { + return BlockStorage::default_allocator(); +} + +// -- + +} // namespace Internal +} // namespace Runtime +} // namespace Halide + +#endif // HALIDE_RUNTIME_MEMORY_ARENA_H diff --git a/src/runtime/internal/memory_resources.h b/src/runtime/internal/memory_resources.h new file mode 100644 index 000000000000..455ce43ab277 --- /dev/null +++ b/src/runtime/internal/memory_resources.h @@ -0,0 +1,281 @@ +#ifndef HALIDE_RUNTIME_MEMORY_RESOURCES_H +#define HALIDE_RUNTIME_MEMORY_RESOURCES_H + +namespace Halide { +namespace Runtime { +namespace Internal { + +// -- + +// Hint for allocation usage indicating whether or not the resource +// is in use, available, or dedicated (and can't be split or shared) +enum class AllocationStatus { + InvalidStatus, + InUse, + Available, + Purgeable, + Dedicated +}; + +// Hint for allocation requests indicating intended usage +// required between host and device address space mappings +enum class MemoryVisibility { + InvalidVisibility, //< invalid enum value + HostOnly, //< host local + DeviceOnly, //< device local + DeviceToHost, //< transfer from device to host + HostToDevice, //< transfer from host to device + DefaultVisibility, //< default visibility (use any valid visibility -- unable to determine prior to usage) +}; + +// Hint for allocation requests indicating intended update +// frequency for modifying the contents of the allocation +enum class MemoryUsage { + InvalidUsage, //< invalid enum value + StaticStorage, //< intended for static storage, whereby the contents will be set once and remain unchanged + DynamicStorage, //< intended for dyanmic storage, whereby the contents will be set frequently and change constantly + UniformStorage, //< intended for fast & small fixed read-only uniform storage (intended for passing shader parameters), whereby the contents will be set once and remain unchanged + TransferSrc, //< intended for staging storage updates, whereby the contents will be used as the source of a transfer + TransferDst, //< intended for staging storage updates, whereby the contents will be used as the destination of a transfer + TransferSrcDst, //< intended for staging storage updates, whereby the contents will be used either as a source or destination of a transfer + DefaultUsage //< default usage (use any valid usage -- unable to determine prior to usage) +}; + +// Hint for allocation requests indicating ideal caching support (if available) +enum class MemoryCaching { + InvalidCaching, //< invalid enum value + Cached, //< cached + Uncached, //< uncached + CachedCoherent, //< cached and coherent + UncachedCoherent, //< uncached but still coherent + DefaultCaching //< default caching (use any valid caching behaviour -- unable to determine prior to usage) +}; + +struct MemoryProperties { + MemoryVisibility visibility = MemoryVisibility::InvalidVisibility; + MemoryUsage usage = MemoryUsage::InvalidUsage; + MemoryCaching caching = MemoryCaching::InvalidCaching; +}; + +// Client-facing struct for exchanging memory block allocation requests +struct MemoryBlock { + void *handle = nullptr; //< client data storing native handle (managed by alloc_block_region/free_block_region) + size_t size = 0; //< allocated size (in bytes) + bool dedicated = false; //< flag indicating whether allocation is one dedicated resource (or split/shared into other resources) + MemoryProperties properties; //< properties for the allocated block +}; + +// Client-facing struct for exchanging memory region allocation requests +struct MemoryRegion { + void *handle = nullptr; //< client data storing native handle (managed by alloc_block_region/free_block_region) + size_t offset = 0; //< offset from base address in block (in bytes) + size_t size = 0; //< allocated size (in bytes) + bool dedicated = false; //< flag indicating whether allocation is one dedicated resource (or split/shared into other resources) + MemoryProperties properties; //< properties for the allocated region +}; + +// Client-facing struct for issuing memory allocation requests +struct MemoryRequest { + size_t offset = 0; //< offset from base address in block (in bytes) + size_t size = 0; //< allocated size (in bytes) + size_t alignment = 0; //< alignment constraint for address + bool dedicated = false; //< flag indicating whether allocation is one dedicated resource (or split/shared into other resources) + MemoryProperties properties; //< properties for the allocated region +}; + +class RegionAllocator; +struct BlockRegion; + +// Internal struct for block resource state +// -- Note: first field must MemoryBlock +struct BlockResource { + MemoryBlock memory; //< memory info for the allocated block + RegionAllocator *allocator = nullptr; //< designated allocator for the block + BlockRegion *regions = nullptr; //< head of linked list of memory regions + size_t reserved = 0; //< number of bytes already reserved to regions +}; + +// Internal struct for block region state +// -- Note: first field must MemoryRegion +struct BlockRegion { + MemoryRegion memory; //< memory info for the allocated region + AllocationStatus status = AllocationStatus::InvalidStatus; //< allocation status indicator + BlockRegion *next_ptr = nullptr; //< pointer to next block region in linked list + BlockRegion *prev_ptr = nullptr; //< pointer to prev block region in linked list + BlockResource *block_ptr = nullptr; //< pointer to parent block resource +}; + +// Returns an aligned byte offset to adjust the given offset based on alignment constraints +// -- Alignment must be power of two! +ALWAYS_INLINE size_t aligned_offset(size_t offset, size_t alignment) { + return (offset + (alignment - 1)) & ~(alignment - 1); +} + +// Returns a padded size to accomodate an adjusted offset due to alignment constraints +// -- Alignment must be power of two! +ALWAYS_INLINE size_t aligned_size(size_t offset, size_t size, size_t alignment) { + size_t actual_offset = aligned_offset(offset, alignment); + size_t padding = actual_offset - offset; + size_t actual_size = padding + size; + return actual_size; +} + +// Clamps the given value to be within the [min_value, max_value] range +ALWAYS_INLINE size_t clamped_size(size_t value, size_t min_value, size_t max_value) { + size_t result = (value < min_value) ? min_value : value; + return (result > max_value) ? max_value : result; +} + +// Offset the untyped pointer by the given number of bytes +ALWAYS_INLINE const void *offset_address(const void *address, size_t byte_offset) { + const uintptr_t base = reinterpret_cast(address); + return reinterpret_cast(base + byte_offset); +} + +// Offset the untyped pointer by the given number of bytes +ALWAYS_INLINE void *offset_address(void *address, size_t byte_offset) { + const uintptr_t base = reinterpret_cast(address); + return reinterpret_cast(base + byte_offset); +} + +// -- + +typedef void *(*AllocateSystemFn)(void *, size_t); +typedef void (*DeallocateSystemFn)(void *, void *); + +ALWAYS_INLINE void *native_system_malloc(void *user_context, size_t bytes) { + return malloc(bytes); +} + +ALWAYS_INLINE void native_system_free(void *user_context, void *ptr) { + free(ptr); +} + +struct SystemMemoryAllocatorFns { + AllocateSystemFn allocate = nullptr; + DeallocateSystemFn deallocate = nullptr; +}; + +struct HalideSystemAllocatorFns { + AllocateSystemFn allocate = halide_malloc; + DeallocateSystemFn deallocate = halide_free; +}; + +typedef void (*AllocateBlockFn)(void *, MemoryBlock *); +typedef void (*DeallocateBlockFn)(void *, MemoryBlock *); + +struct MemoryBlockAllocatorFns { + AllocateBlockFn allocate = nullptr; + DeallocateBlockFn deallocate = nullptr; +}; + +typedef void (*AllocateRegionFn)(void *, MemoryRegion *); +typedef void (*DeallocateRegionFn)(void *, MemoryRegion *); + +struct MemoryRegionAllocatorFns { + AllocateRegionFn allocate = nullptr; + DeallocateRegionFn deallocate = nullptr; +}; + +// -- + +} // namespace Internal +} // namespace Runtime +} // namespace Halide + +// -- + +extern "C" { + +WEAK const char *halide_memory_visibility_name(MemoryVisibility value) { + switch (value) { + case MemoryVisibility::InvalidVisibility: { + return "InvalidVisibility"; + } + case MemoryVisibility::DefaultVisibility: { + return "DefaultVisibility"; + } + case MemoryVisibility::HostOnly: { + return "HostOnly"; + } + case MemoryVisibility::DeviceOnly: { + return "DeviceOnly"; + } + case MemoryVisibility::HostToDevice: { + return "HostToDevice"; + } + case MemoryVisibility::DeviceToHost: { + return "DeviceToHost"; + } + default: { + return ""; + } + }; + return ""; +} + +WEAK const char *halide_memory_usage_name(MemoryUsage value) { + switch (value) { + case MemoryUsage::InvalidUsage: { + return "InvalidUsage"; + } + case MemoryUsage::DefaultUsage: { + return "DefaultUsage"; + } + case MemoryUsage::StaticStorage: { + return "StaticStorage"; + } + case MemoryUsage::DynamicStorage: { + return "DynamicStorage"; + } + case MemoryUsage::UniformStorage: { + return "UniformStorage"; + } + case MemoryUsage::TransferSrc: { + return "TransferSrc"; + } + case MemoryUsage::TransferDst: { + return "TransferDst"; + } + case MemoryUsage::TransferSrcDst: { + return "TransferSrcDst"; + } + default: { + return ""; + } + }; + return ""; +} + +WEAK const char *halide_memory_caching_name(MemoryCaching value) { + switch (value) { + case MemoryCaching::InvalidCaching: { + return "InvalidCaching"; + } + case MemoryCaching::DefaultCaching: { + return "DefaultCaching"; + } + case MemoryCaching::Cached: { + return "Cached"; + } + case MemoryCaching::Uncached: { + return "Uncached"; + } + case MemoryCaching::CachedCoherent: { + return "CachedCoherent"; + } + case MemoryCaching::UncachedCoherent: { + return "UncachedCoherent"; + } + default: { + return ""; + } + }; + return ""; +} + +} // extern "C" + +// -- + +#endif // HALIDE_RUNTIME_MEMORY_RESOURCES_H diff --git a/src/runtime/internal/pointer_table.h b/src/runtime/internal/pointer_table.h new file mode 100644 index 000000000000..b37a86338028 --- /dev/null +++ b/src/runtime/internal/pointer_table.h @@ -0,0 +1,370 @@ +#ifndef HALIDE_RUNTIME_POINTER_TABLE_H +#define HALIDE_RUNTIME_POINTER_TABLE_H + +#include "memory_resources.h" + +namespace Halide { +namespace Runtime { +namespace Internal { + +// Dynamically resizable array for storing untyped pointers +// -- Implementation uses memcpy/memmove for copying +// -- Customizable allocator ... default uses NativeSystemAllocator +class PointerTable { +public: + static constexpr size_t default_capacity = 32; // smallish + + PointerTable(void *user_context, size_t initial_capacity = 0, const SystemMemoryAllocatorFns &sma = default_allocator()); + PointerTable(const PointerTable &other); + ~PointerTable(); + + void initialize(void *user_context, size_t initial_capacity = 0, const SystemMemoryAllocatorFns &sma = default_allocator()); + + PointerTable &operator=(const PointerTable &other); + bool operator==(const PointerTable &other) const; + bool operator!=(const PointerTable &other) const; + + void reserve(void *user_context, size_t capacity, bool free_existing = false); + void resize(void *user_context, size_t entry_count, bool realloc = true); + + void assign(void *user_context, size_t index, const void *entry_ptr); + void insert(void *user_context, size_t index, const void *entry_ptr); + void prepend(void *user_context, const void *entry_ptr); + void append(void *user_context, const void *entry_ptr); + void remove(void *user_context, size_t index); + + void fill(void *user_context, const void **array, size_t array_size); + void insert(void *user_context, size_t index, const void **array, size_t array_size); + void replace(void *user_context, size_t index, const void **array, size_t array_size); + void prepend(void *user_context, const void **array, size_t array_size); + void append(void *user_context, const void **array, size_t array_size); + void remove(void *user_context, size_t index, size_t entry_count); + + void pop_front(void *user_context); + void pop_back(void *user_context); + void shrink_to_fit(void *user_context); + void clear(void *user_context); + void destroy(void *user_context); + + bool empty() const; + size_t size() const; + + void *operator[](size_t index); + void *operator[](size_t index) const; + + void **data(); + const void **data() const; + + void *front(); + void *back(); + + const SystemMemoryAllocatorFns ¤t_allocator() const; + static const SystemMemoryAllocatorFns &default_allocator(); + +private: + void allocate(void *user_context, size_t capacity); + + void **ptr = nullptr; + size_t count = 0; + size_t capacity = 0; + SystemMemoryAllocatorFns allocator; +}; + +PointerTable::PointerTable(void *user_context, size_t initial_capacity, const SystemMemoryAllocatorFns &sma) + : allocator(sma) { + halide_abort_if_false(user_context, allocator.allocate != nullptr); + halide_abort_if_false(user_context, allocator.deallocate != nullptr); + if (initial_capacity) { + reserve(user_context, initial_capacity); + } +} + +PointerTable::PointerTable(const PointerTable &other) + : PointerTable(nullptr, 0, other.allocator) { + if (other.capacity) { + ptr = static_cast(allocator.allocate(nullptr, other.capacity * sizeof(void *))); + capacity = other.capacity; + } + if (ptr && other.count != 0) { + count = other.count; + memcpy(this->ptr, other.ptr, count * sizeof(void *)); + } +} + +PointerTable::~PointerTable() { + destroy(nullptr); +} + +void PointerTable::destroy(void *user_context) { + halide_abort_if_false(user_context, allocator.deallocate != nullptr); + if (ptr != nullptr) { + allocator.deallocate(user_context, ptr); + } + capacity = count = 0; + ptr = nullptr; +} + +void PointerTable::initialize(void *user_context, size_t initial_capacity, const SystemMemoryAllocatorFns &sma) { + allocator = sma; + capacity = count = 0; + ptr = nullptr; + if (initial_capacity) { + reserve(user_context, initial_capacity); + } +} + +PointerTable &PointerTable::operator=(const PointerTable &other) { + if (&other != this) { + resize(nullptr, other.count); + if (count != 0 && other.ptr != nullptr) { + memcpy(ptr, other.ptr, count * sizeof(void *)); + } + } + return *this; +} + +bool PointerTable::operator==(const PointerTable &other) const { + if (count != other.count) { + return false; + } + return memcmp(this->ptr, other.ptr, this->size() * sizeof(void *)) == 0; +} + +bool PointerTable::operator!=(const PointerTable &other) const { + return !(*this == other); +} + +void PointerTable::fill(void *user_context, const void **array, size_t array_size) { + if (array_size != 0) { + resize(user_context, array_size); + memcpy(this->ptr, array, array_size * sizeof(void *)); + count = array_size; + } +} + +void PointerTable::assign(void *user_context, size_t index, const void *entry_ptr) { + halide_debug_assert(user_context, index < count); + ptr[index] = const_cast(entry_ptr); +} + +void PointerTable::prepend(void *user_context, const void *entry_ptr) { + insert(user_context, 0, &entry_ptr, 1); +} + +void PointerTable::append(void *user_context, const void *entry_ptr) { + append(user_context, &entry_ptr, 1); +} + +void PointerTable::pop_front(void *user_context) { + halide_debug_assert(user_context, count > 0); + remove(user_context, 0); +} + +void PointerTable::pop_back(void *user_context) { + halide_debug_assert(user_context, count > 0); + resize(user_context, size() - 1); +} + +void PointerTable::clear(void *user_context) { + resize(user_context, 0); +} + +void PointerTable::reserve(void *user_context, size_t new_capacity, bool free_existing) { + new_capacity = max(new_capacity, count); + if ((new_capacity < capacity) && !free_existing) { + new_capacity = capacity; + } + allocate(user_context, new_capacity); +} + +void PointerTable::resize(void *user_context, size_t entry_count, bool realloc) { + size_t current_size = capacity; + size_t requested_size = entry_count; + size_t minimum_size = default_capacity; + size_t actual_size = current_size; + count = requested_size; + +#ifdef DEBUG_RUNTIME + debug(user_context) << "PointerTable: Resize (" + << "requested_size=" << (int32_t)requested_size << " " + << "current_size=" << (int32_t)current_size << " " + << "minimum_size=" << (int32_t)minimum_size << " " + << "sizeof(void*)=" << (int32_t)sizeof(void *) << " " + << "realloc=" << (realloc ? "true" : "false") << ")...\n"; +#endif + + // increase capacity upto 1.5x existing (or at least min_capacity) + if (requested_size > current_size) { + actual_size = max(requested_size, max(current_size * 3 / 2, minimum_size)); + } else if (!realloc) { + return; + } + + allocate(user_context, actual_size); +} + +void PointerTable::shrink_to_fit(void *user_context) { + if (capacity > count) { + void *new_ptr = nullptr; + if (count > 0) { + size_t bytes = count * sizeof(void *); + new_ptr = allocator.allocate(user_context, bytes); + memcpy(new_ptr, ptr, bytes); + } + allocator.deallocate(user_context, ptr); + capacity = count; + ptr = static_cast(new_ptr); + } +} + +void PointerTable::insert(void *user_context, size_t index, const void *entry_ptr) { + const void *addr = reinterpret_cast(entry_ptr); + insert(user_context, index, &addr, 1); +} + +void PointerTable::remove(void *user_context, size_t index) { + remove(user_context, index, 1); +} + +void PointerTable::remove(void *user_context, size_t index, size_t entry_count) { + halide_debug_assert(user_context, index < count); + const size_t last_index = size(); + if (index < (last_index - entry_count)) { + size_t dst_offset = index * sizeof(void *); + size_t src_offset = (index + entry_count) * sizeof(void *); + size_t bytes = (last_index - index - entry_count) * sizeof(void *); + +#ifdef DEBUG_RUNTIME + debug(user_context) << "PointerTable: Remove (" + << "index=" << (int32_t)index << " " + << "entry_count=" << (int32_t)entry_count << " " + << "last_index=" << (int32_t)last_index << " " + << "src_offset=" << (int32_t)src_offset << " " + << "dst_offset=" << (int32_t)dst_offset << " " + << "bytes=" << (int32_t)bytes << ")...\n"; +#endif + memmove(ptr + dst_offset, ptr + src_offset, bytes); + } + resize(user_context, last_index - entry_count); +} + +void PointerTable::replace(void *user_context, size_t index, const void **array, size_t array_size) { + halide_debug_assert(user_context, index < count); + size_t remaining = count - index; + size_t copy_count = min(remaining, array_size); + +#ifdef DEBUG_RUNTIME + + debug(user_context) << "PointerTable: Replace (" + << "index=" << (int32_t)index << " " + << "array_size=" << (int32_t)array_size << " " + << "remaining=" << (int32_t)remaining << " " + << "copy_count=" << (int32_t)copy_count << " " + << "capacity=" << (int32_t)capacity << ")...\n"; +#endif + + halide_debug_assert(user_context, remaining > 0); + memcpy(ptr + index, array, copy_count * sizeof(void *)); + count = max(count, index + copy_count); +} + +void PointerTable::insert(void *user_context, size_t index, const void **array, size_t array_size) { + halide_debug_assert(user_context, index <= count); + const size_t last_index = size(); + resize(user_context, last_index + array_size); + if (index < last_index) { + size_t src_offset = index * sizeof(void *); + size_t dst_offset = (index + array_size) * sizeof(void *); + size_t bytes = (last_index - index) * sizeof(void *); + memmove(ptr + dst_offset, ptr + src_offset, bytes); + } + replace(user_context, index, array, array_size); +} + +void PointerTable::prepend(void *user_context, const void **array, size_t array_size) { + insert(user_context, 0, array, array_size); +} + +void PointerTable::append(void *user_context, const void **array, size_t array_size) { + const size_t last_index = size(); + insert(user_context, last_index, array, array_size); +} + +bool PointerTable::empty() const { + return count == 0; +} + +size_t PointerTable::size() const { + return count; +} + +void *PointerTable::operator[](size_t index) { + halide_debug_assert(nullptr, index < capacity); + return ptr[index]; +} + +void *PointerTable::operator[](size_t index) const { + halide_debug_assert(nullptr, index < capacity); + return ptr[index]; +} + +void **PointerTable::data() { + return ptr; +} + +void *PointerTable::front() { + halide_debug_assert(nullptr, count > 0); + return ptr[0]; +} + +void *PointerTable::back() { + halide_debug_assert(nullptr, count > 0); + size_t index = count - 1; + return ptr[index]; +} + +const void **PointerTable::data() const { + return const_cast(ptr); +} + +void PointerTable::allocate(void *user_context, size_t new_capacity) { + if (new_capacity != capacity) { + halide_abort_if_false(user_context, allocator.allocate != nullptr); + size_t bytes = new_capacity * sizeof(void *); + +#ifdef DEBUG_RUNTIME + debug(user_context) << "PointerTable: Allocating (bytes=" << (int32_t)bytes << " allocator=" << (void *)allocator.allocate << ")...\n"; +#endif + + void *new_ptr = bytes ? allocator.allocate(user_context, bytes) : nullptr; + if (count != 0 && ptr != nullptr && new_ptr != nullptr) { + memcpy(new_ptr, ptr, count * sizeof(void *)); + } + if (ptr != nullptr) { + halide_abort_if_false(user_context, allocator.deallocate != nullptr); + allocator.deallocate(user_context, ptr); + } + capacity = new_capacity; + ptr = static_cast(new_ptr); + } +} + +const SystemMemoryAllocatorFns & +PointerTable::current_allocator() const { + return this->allocator; +} + +const SystemMemoryAllocatorFns & +PointerTable::default_allocator() { + static SystemMemoryAllocatorFns native_allocator = { + native_system_malloc, native_system_free}; + return native_allocator; +} + +// -- + +} // namespace Internal +} // namespace Runtime +} // namespace Halide + +#endif // HALIDE_RUNTIME_POINTER_TABLE_H diff --git a/src/runtime/internal/region_allocator.h b/src/runtime/internal/region_allocator.h new file mode 100644 index 000000000000..5deba8c644fc --- /dev/null +++ b/src/runtime/internal/region_allocator.h @@ -0,0 +1,467 @@ +#ifndef HALIDE_RUNTIME_REGION_ALLOCATOR_H +#define HALIDE_RUNTIME_REGION_ALLOCATOR_H + +#include "memory_arena.h" +#include "memory_resources.h" + +namespace Halide { +namespace Runtime { +namespace Internal { + +// -- + +/** Allocator class interface for sub-allocating a contiguous + * memory block into smaller regions of memory. This class only + * manages the address creation for the regions -- allocation + * callback functions are used to request the memory from the + * necessary system or API calls. This class is intended to be + * used inside of a higher level memory management class that + * provides thread safety, policy management and API + * integration for a specific runtime API (eg Vulkan, OpenCL, etc) + */ +class RegionAllocator { +public: + // disable copy constructors and assignment + RegionAllocator(const RegionAllocator &) = delete; + RegionAllocator &operator=(const RegionAllocator &) = delete; + + // disable non-factory based construction + RegionAllocator() = delete; + ~RegionAllocator() = delete; + + // Allocators for the different types of memory we need to allocate + struct MemoryAllocators { + SystemMemoryAllocatorFns system; + MemoryRegionAllocatorFns region; + }; + + // Factory methods for creation / destruction + static RegionAllocator *create(void *user_context, BlockResource *block, const MemoryAllocators &ma); + static void destroy(void *user_context, RegionAllocator *region_allocator); + + // Returns the allocator class instance for the given allocation (or nullptr) + static RegionAllocator *find_allocator(void *user_context, MemoryRegion *memory_region); + + // Public interface methods + MemoryRegion *reserve(void *user_context, const MemoryRequest &request); + void reclaim(void *user_context, MemoryRegion *memory_region); + bool collect(void *user_context); //< returns true if any blocks were removed + void release(void *user_context); + void destroy(void *user_context); + + // Returns the currently managed block resource + BlockResource *block_resource() const; + +private: + // Initializes a new instance + void initialize(void *user_context, BlockResource *block, const MemoryAllocators &ma); + + // Search through allocated block regions (Best-Fit) + BlockRegion *find_block_region(void *user_context, const MemoryRequest &request); + + // Returns true if neighbouring block regions to the given region can be coalesced into one + bool can_coalesce(BlockRegion *region); + + // Merges available neighbouring block regions into the given region + BlockRegion *coalesce_block_regions(void *user_context, BlockRegion *region); + + // Returns true if the given region can be split to accomadate the given size + bool can_split(BlockRegion *region, size_t size); + + // Splits the given block region into a smaller region to accomadate the given size, followed by empty space for the remaining + BlockRegion *split_block_region(void *user_context, BlockRegion *region, size_t size, size_t alignment); + + // Creates a new block region and adds it to the region list + BlockRegion *create_block_region(void *user_context, const MemoryProperties &properties, size_t offset, size_t size, bool dedicated); + + // Creates a new block region and adds it to the region list + void destroy_block_region(void *user_context, BlockRegion *region); + + // Invokes the allocation callback to allocate memory for the block region + void alloc_block_region(void *user_context, BlockRegion *region); + + // Releases a block region and leaves it in the list for further allocations + void release_block_region(void *user_context, BlockRegion *region); + + // Invokes the deallocation callback to free memory for the block region + void free_block_region(void *user_context, BlockRegion *region); + + // Returns true if the given block region is compatible with the given properties + bool is_compatible_block_region(const BlockRegion *region, const MemoryProperties &properties) const; + + BlockResource *block = nullptr; + MemoryArena *arena = nullptr; + MemoryAllocators allocators; +}; + +RegionAllocator *RegionAllocator::create(void *user_context, BlockResource *block_resource, const MemoryAllocators &allocators) { + halide_abort_if_false(user_context, allocators.system.allocate != nullptr); + RegionAllocator *result = reinterpret_cast( + allocators.system.allocate(user_context, sizeof(RegionAllocator))); + + if (result == nullptr) { + halide_error(user_context, "RegionAllocator: Failed to create instance! Out of memory!\n"); + return nullptr; + } + + result->initialize(user_context, block_resource, allocators); + return result; +} + +void RegionAllocator::destroy(void *user_context, RegionAllocator *instance) { + halide_abort_if_false(user_context, instance != nullptr); + const MemoryAllocators &allocators = instance->allocators; + instance->destroy(user_context); + halide_abort_if_false(user_context, allocators.system.deallocate != nullptr); + allocators.system.deallocate(user_context, instance); +} + +void RegionAllocator::initialize(void *user_context, BlockResource *mb, const MemoryAllocators &ma) { + block = mb; + allocators = ma; + arena = MemoryArena::create(user_context, {sizeof(BlockRegion), MemoryArena::default_capacity, 0}, allocators.system); + halide_abort_if_false(user_context, arena != nullptr); + block->allocator = this; + block->regions = create_block_region( + user_context, + block->memory.properties, + 0, block->memory.size, + block->memory.dedicated); +} + +MemoryRegion *RegionAllocator::reserve(void *user_context, const MemoryRequest &request) { + halide_abort_if_false(user_context, request.size > 0); + size_t remaining = block->memory.size - block->reserved; + if (remaining < request.size) { +#ifdef DEBUG_RUNTIME + debug(user_context) << "RegionAllocator: Unable to reserve more memory from block " + << "-- requested size (" << (int32_t)(request.size) << " bytes) " + << "greater than available (" << (int32_t)(remaining) << " bytes)!\n"; +#endif + return nullptr; + } + + BlockRegion *block_region = find_block_region(user_context, request); + if (block_region == nullptr) { +#ifdef DEBUG_RUNTIME + debug(user_context) << "RegionAllocator: Failed to locate region for requested size (" + << (int32_t)(request.size) << " bytes)!\n"; +#endif + return nullptr; + } + + if (can_split(block_region, request.size)) { +#ifdef DEBUG_RUNTIME + debug(user_context) << "RegionAllocator: Splitting region of size ( " << (int32_t)(block_region->memory.size) << ") " + << "to accomodate requested size (" << (int32_t)(request.size) << " bytes)!\n"; +#endif + split_block_region(user_context, block_region, request.size, request.alignment); + } + + alloc_block_region(user_context, block_region); + return reinterpret_cast(block_region); +} + +void RegionAllocator::reclaim(void *user_context, MemoryRegion *memory_region) { + BlockRegion *block_region = reinterpret_cast(memory_region); + halide_abort_if_false(user_context, block_region != nullptr); + halide_abort_if_false(user_context, block_region->block_ptr == block); + free_block_region(user_context, block_region); + if (can_coalesce(block_region)) { + block_region = coalesce_block_regions(user_context, block_region); + } +} + +RegionAllocator *RegionAllocator::find_allocator(void *user_context, MemoryRegion *memory_region) { + BlockRegion *block_region = reinterpret_cast(memory_region); + halide_abort_if_false(user_context, block_region != nullptr); + halide_abort_if_false(user_context, block_region->block_ptr != nullptr); + return block_region->block_ptr->allocator; +} + +BlockRegion *RegionAllocator::find_block_region(void *user_context, const MemoryRequest &request) { + BlockRegion *result = nullptr; + for (BlockRegion *block_region = block->regions; block_region != nullptr; block_region = block_region->next_ptr) { + + if (block_region->status != AllocationStatus::Available) { + continue; + } + + // skip incompatible block regions for this request + if (!is_compatible_block_region(block_region, request.properties)) { + continue; + } + + // is the requested size larger than the current region? + if (request.size > block_region->memory.size) { + continue; + } + + size_t actual_size = aligned_size(block_region->memory.offset, request.size, request.alignment); + + // is the adjusted size larger than the current region? + if (actual_size > block_region->memory.size) { + continue; + } + + // will the adjusted size fit within the remaining unallocated space? + if ((actual_size + block->reserved) < block->memory.size) { + result = block_region; // best-fit! + break; + } + } + return result; +} + +bool RegionAllocator::can_coalesce(BlockRegion *block_region) { + if (block_region == nullptr) { + return false; + } + if (block_region->prev_ptr && (block_region->prev_ptr->status == AllocationStatus::Available)) { + return true; + } + if (block_region->next_ptr && (block_region->next_ptr->status == AllocationStatus::Available)) { + return true; + } + return false; +} + +BlockRegion *RegionAllocator::coalesce_block_regions(void *user_context, BlockRegion *block_region) { + if (block_region->prev_ptr && (block_region->prev_ptr->status == AllocationStatus::Available)) { + BlockRegion *prev_region = block_region->prev_ptr; + +#ifdef DEBUG_RUNTIME + debug(user_context) << "RegionAllocator: Coalescing " + << "previous region (offset=" << (int32_t)prev_region->memory.offset << " size=" << (int32_t)(prev_region->memory.size) << " bytes) " + << "into current region (offset=" << (int32_t)block_region->memory.offset << " size=" << (int32_t)(block_region->memory.size) << " bytes)\n!"; +#endif + + prev_region->next_ptr = block_region->next_ptr; + if (block_region->next_ptr) { + block_region->next_ptr->prev_ptr = prev_region; + } + prev_region->memory.size += block_region->memory.size; + destroy_block_region(user_context, block_region); + block_region = prev_region; + } + + if (block_region->next_ptr && (block_region->next_ptr->status == AllocationStatus::Available)) { + BlockRegion *next_region = block_region->next_ptr; + +#ifdef DEBUG_RUNTIME + debug(user_context) << "RegionAllocator: Coalescing " + << "next region (offset=" << (int32_t)next_region->memory.offset << " size=" << (int32_t)(next_region->memory.size) << " bytes) " + << "into current region (offset=" << (int32_t)block_region->memory.offset << " size=" << (int32_t)(block_region->memory.size) << " bytes)!\n"; +#endif + + if (next_region->next_ptr) { + next_region->next_ptr->prev_ptr = block_region; + } + block_region->next_ptr = next_region->next_ptr; + block_region->memory.size += next_region->memory.size; + destroy_block_region(user_context, next_region); + } + + return block_region; +} + +bool RegionAllocator::can_split(BlockRegion *block_region, size_t size) { + return (block_region && (block_region->memory.size > size)); +} + +BlockRegion *RegionAllocator::split_block_region(void *user_context, BlockRegion *block_region, size_t size, size_t alignment) { + size_t adjusted_size = aligned_size(block_region->memory.offset, size, alignment); + size_t adjusted_offset = aligned_offset(block_region->memory.offset, alignment); + + size_t empty_offset = adjusted_offset + size; + size_t empty_size = block_region->memory.size - adjusted_size; + +#ifdef DEBUG_RUNTIME + debug(user_context) << "RegionAllocator: Splitting " + << "current region (offset=" << (int32_t)block_region->memory.offset << " size=" << (int32_t)(block_region->memory.size) << " bytes) " + << "to create empty region (offset=" << (int32_t)empty_offset << " size=" << (int32_t)(empty_size) << " bytes)!\n"; +#endif + + BlockRegion *next_region = block_region->next_ptr; + BlockRegion *empty_region = create_block_region(user_context, + block_region->memory.properties, + empty_offset, empty_size, + block_region->memory.dedicated); + halide_abort_if_false(user_context, empty_region != nullptr); + + empty_region->next_ptr = next_region; + if (next_region) { + next_region->prev_ptr = empty_region; + } + block_region->next_ptr = empty_region; + block_region->memory.size = size; + return empty_region; +} + +BlockRegion *RegionAllocator::create_block_region(void *user_context, const MemoryProperties &properties, size_t offset, size_t size, bool dedicated) { +#ifdef DEBUG_RUNTIME + debug(user_context) << "RegionAllocator: Creating block region (" + << "user_context=" << (void *)(user_context) << " " + << "offset=" << (uint32_t)offset << " " + << "size=" << (uint32_t)size << " " + << "dedicated=" << (dedicated ? "true" : "false") << " " + << "usage=" << halide_memory_usage_name(properties.usage) << " " + << "caching=" << halide_memory_caching_name(properties.caching) << " " + << "visibility=" << halide_memory_visibility_name(properties.visibility) << ") ...\n"; +#endif + + BlockRegion *block_region = static_cast(arena->reserve(user_context, true)); + + if (block_region == nullptr) { + error(user_context) << "RegionAllocator: Failed to allocate new block region!\n"; + return nullptr; + } + +#ifdef DEBUG_RUNTIME + debug(user_context) << "RegionAllocator: Added block region (" + << "user_context=" << (void *)(user_context) << " " + << "block_region=" << (void *)(block_region) << ") ...\n"; +#endif + + block_region->memory.offset = offset; + block_region->memory.size = size; + block_region->memory.properties = properties; + block_region->memory.dedicated = dedicated; + block_region->status = AllocationStatus::Available; + block_region->block_ptr = block; + return block_region; +} + +void RegionAllocator::release_block_region(void *user_context, BlockRegion *block_region) { +#ifdef DEBUG_RUNTIME + debug(user_context) << "RegionAllocator: Releasing block region (" + << "user_context=" << (void *)(user_context) << " " + << "block_region=" << (void *)(block_region) << ") ...\n"; +#endif + block_region->status = AllocationStatus::Available; +} + +void RegionAllocator::destroy_block_region(void *user_context, BlockRegion *block_region) { +#ifdef DEBUG_RUNTIME + debug(user_context) << "RegionAllocator: Destroying block region (" + << "user_context=" << (void *)(user_context) << " " + << "block_region=" << (void *)(block_region) << ") ...\n"; +#endif + + free_block_region(user_context, block_region); + arena->reclaim(user_context, block_region); +} + +void RegionAllocator::alloc_block_region(void *user_context, BlockRegion *block_region) { +#ifdef DEBUG_RUNTIME + debug(user_context) << "RegionAllocator: Allocating region (size=" << (int32_t)(block_region->memory.size) << ", offset=" << (int32_t)block_region->memory.offset << ")!\n"; +#endif + halide_abort_if_false(user_context, allocators.region.allocate != nullptr); + halide_abort_if_false(user_context, block_region->status == AllocationStatus::Available); + MemoryRegion *memory_region = &(block_region->memory); + allocators.region.allocate(user_context, memory_region); + block_region->status = block_region->memory.dedicated ? AllocationStatus::Dedicated : AllocationStatus::InUse; + block->reserved += block_region->memory.size; +} + +void RegionAllocator::free_block_region(void *user_context, BlockRegion *block_region) { +#ifdef DEBUG_RUNTIME + debug(user_context) << "RegionAllocator: Freeing block region (" + << "user_context=" << (void *)(user_context) << " " + << "block_region=" << (void *)(block_region) << ") ...\n"; +#endif + if ((block_region->status == AllocationStatus::InUse) || + (block_region->status == AllocationStatus::Dedicated)) { + debug(user_context) << "RegionAllocator: Deallocating region (size=" << (int32_t)(block_region->memory.size) << ", offset=" << (int32_t)block_region->memory.offset << ")!\n"; + halide_abort_if_false(user_context, allocators.region.deallocate != nullptr); + MemoryRegion *memory_region = &(block_region->memory); + allocators.region.deallocate(user_context, memory_region); + block->reserved -= block_region->memory.size; + block_region->memory.size = 0; + } + block_region->status = AllocationStatus::Available; +} + +void RegionAllocator::release(void *user_context) { +#ifdef DEBUG_RUNTIME + debug(user_context) << "RegionAllocator: Releasing all regions (" + << "user_context=" << (void *)(user_context) << ") ...\n"; +#endif + for (BlockRegion *block_region = block->regions; block_region != nullptr; block_region = block_region->next_ptr) { + release_block_region(user_context, block_region); + } +} + +bool RegionAllocator::collect(void *user_context) { +#ifdef DEBUG_RUNTIME + debug(user_context) << "RegionAllocator: Collecting free block regions (" + << "user_context=" << (void *)(user_context) << ") ...\n"; +#endif + bool result = false; + for (BlockRegion *block_region = block->regions; block_region != nullptr; block_region = block_region->next_ptr) { + if (block_region->status == AllocationStatus::Available) { + if (can_coalesce(block_region)) { + block_region = coalesce_block_regions(user_context, block_region); + result = true; + } + } + } + return result; +} + +void RegionAllocator::destroy(void *user_context) { +#ifdef DEBUG_RUNTIME + debug(user_context) << "RegionAllocator: Destroying all block regions (" + << "user_context=" << (void *)(user_context) << ") ...\n"; +#endif + for (BlockRegion *block_region = block->regions; block_region != nullptr;) { + + if (block_region->next_ptr == nullptr) { + destroy_block_region(user_context, block_region); + block_region = nullptr; + } else { + BlockRegion *prev_region = block_region; + block_region = block_region->next_ptr; + destroy_block_region(user_context, prev_region); + } + } + block->reserved = 0; + block->regions = nullptr; + block->allocator = nullptr; + MemoryArena::destroy(user_context, arena); + arena = nullptr; +} + +bool RegionAllocator::is_compatible_block_region(const BlockRegion *block_region, const MemoryProperties &properties) const { + if (properties.caching != MemoryCaching::DefaultCaching) { + if (properties.caching != block_region->memory.properties.caching) { + return false; + } + } + + if (properties.visibility != MemoryVisibility::DefaultVisibility) { + if (properties.visibility != block_region->memory.properties.visibility) { + return false; + } + } + + if (properties.usage != MemoryUsage::DefaultUsage) { + if (properties.usage != block_region->memory.properties.usage) { + return false; + } + } + + return true; +} + +BlockResource *RegionAllocator::block_resource() const { + return block; +} + +// -- + +} // namespace Internal +} // namespace Runtime +} // namespace Halide + +#endif // HALIDE_RUNTIME_REGION_ALLOCATOR_H diff --git a/src/runtime/internal/string_storage.h b/src/runtime/internal/string_storage.h new file mode 100644 index 000000000000..ac7dac69215c --- /dev/null +++ b/src/runtime/internal/string_storage.h @@ -0,0 +1,304 @@ +#ifndef HALIDE_RUNTIME_STRING_STORAGE_H +#define HALIDE_RUNTIME_STRING_STORAGE_H + +#include "block_storage.h" + +namespace Halide { +namespace Runtime { +namespace Internal { + +// Static utility functions for dealing with string data +struct StringUtils { + static bool is_empty(const char *str) { + if (str == nullptr) { + return true; + } + if (str[0] == '\0') { + return true; + } + return false; + } + + // count the number of delimited string tokens + static size_t count_tokens(const char *str, const char *delim) { + if (StringUtils::is_empty(str)) { + return 0; + } + if (StringUtils::is_empty(delim)) { + return 1; + } // no delim ... string is one token + + size_t count = 0; + const char *ptr = str; + size_t delim_length = strlen(delim); + while (!StringUtils::is_empty(ptr)) { + const char *next_delim = strstr(ptr, delim); + ptr = (next_delim != nullptr) ? (next_delim + delim_length) : nullptr; + ++count; + } + return count; + } + + // retuns true if s1 contains s2 (within n characters) + static bool contains(const char *s1, const char *s2, size_t n) { + if (is_empty(s2)) { + return true; + } // s2 is empty ... return true to match strstr + char starts_with = *s2; + for (size_t length = strlen(s2); length <= n; n--, s1++) { + if (*s1 == starts_with) { + for (size_t i = 1; i <= length; i++) { + if (i == length) { + return true; + } + if (s1[i] != s2[i]) { + break; + } + } + } + } + return false; + } + + static size_t count_length(const char *str, size_t max_chars) { + const char *ptr = str; + while (!StringUtils::is_empty(ptr) && ((size_t(ptr - str)) < max_chars)) { + ++ptr; + } + return size_t(ptr - str); + } +}; + +// -- +// Storage class for handling c-string data (based on block storage) +// -- Intended for building and maintaining string data w/8-bit chars +// +class StringStorage { +public: + StringStorage(void *user_context = nullptr, uint32_t capacity = 0, const SystemMemoryAllocatorFns &sma = default_allocator()); + StringStorage(const StringStorage &other) = default; + ~StringStorage(); + + // Factory methods for creation / destruction + static StringStorage *create(void *user_context, const SystemMemoryAllocatorFns &ma); + static void destroy(void *user_context, StringStorage *string_storage); + + void initialize(void *user_context, uint32_t capacity = 0, const SystemMemoryAllocatorFns &sma = default_allocator()); + void destroy(void *user_context); + + StringStorage &operator=(const StringStorage &other); + bool operator==(const StringStorage &other) const; + bool operator!=(const StringStorage &other) const; + + bool contains(const char *str) const; + bool contains(const StringStorage &other) const; + + void reserve(void *user_context, size_t length); + void assign(void *user_context, char ch); + void assign(void *user_context, const char *str, size_t length = 0); // if length is zero, strlen is used + void append(void *user_context, char ch); + void append(void *user_context, const char *str, size_t length = 0); // if length is zero, strlen is used + void prepend(void *user_context, char ch); + void prepend(void *user_context, const char *str, size_t length = 0); // if length is zero, strlen is used + void clear(void *user_context); + void terminate(void *user_context, size_t length); + + size_t length() const; + const char *data() const; + + const SystemMemoryAllocatorFns ¤t_allocator() const; + static const SystemMemoryAllocatorFns &default_allocator(); + +private: + BlockStorage contents; +}; + +StringStorage::StringStorage(void *user_context, uint32_t capacity, const SystemMemoryAllocatorFns &sma) + : contents(user_context, {sizeof(char), 32, 32}, sma) { + if (capacity) { + contents.reserve(user_context, capacity); + } +} + +StringStorage::~StringStorage() { + destroy(nullptr); +} + +StringStorage *StringStorage::create(void *user_context, const SystemMemoryAllocatorFns &system_allocator) { + halide_abort_if_false(user_context, system_allocator.allocate != nullptr); + StringStorage *result = reinterpret_cast( + system_allocator.allocate(user_context, sizeof(StringStorage))); + + if (result == nullptr) { + halide_error(user_context, "StringStorage: Failed to create instance! Out of memory!\n"); + return nullptr; + } + + result->initialize(user_context, 32, system_allocator); + return result; +} + +void StringStorage::destroy(void *user_context, StringStorage *instance) { + halide_abort_if_false(user_context, instance != nullptr); + const SystemMemoryAllocatorFns &system_allocator = instance->current_allocator(); + instance->destroy(user_context); + halide_abort_if_false(user_context, system_allocator.deallocate != nullptr); + system_allocator.deallocate(user_context, instance); +} + +StringStorage &StringStorage::operator=(const StringStorage &other) { + if (&other != this) { + assign(nullptr, other.data(), other.length()); + } + return *this; +} + +bool StringStorage::contains(const char *str) const { + if (contents.empty()) { + return false; + } + const char *this_str = static_cast(contents.data()); + return StringUtils::contains(this_str, str, contents.size()); +} + +bool StringStorage::contains(const StringStorage &other) const { + if (contents.empty()) { + return false; + } + if (other.contents.empty()) { + return false; + } + const char *this_str = static_cast(contents.data()); + const char *other_str = static_cast(other.contents.data()); + return StringUtils::contains(this_str, other_str, contents.size()); +} + +bool StringStorage::operator==(const StringStorage &other) const { + if (contents.size() != other.contents.size()) { + return false; + } + const char *this_str = static_cast(contents.data()); + const char *other_str = static_cast(other.contents.data()); + return strncmp(this_str, other_str, contents.size()) == 0; +} + +bool StringStorage::operator!=(const StringStorage &other) const { + return !(*this == other); +} + +void StringStorage::reserve(void *user_context, size_t length) { + contents.reserve(user_context, length + 1); // leave room for termination + contents.resize(user_context, length, false); + terminate(user_context, length); +} + +void StringStorage::assign(void *user_context, char ch) { + reserve(user_context, 1); + char *ptr = static_cast(contents[0]); + (*ptr) = ch; + terminate(user_context, 1); +} + +void StringStorage::assign(void *user_context, const char *str, size_t length) { + if (StringUtils::is_empty(str)) { + return; + } + if (length == 0) { + length = strlen(str); + } + reserve(user_context, length); + contents.replace(user_context, 0, str, length); + terminate(user_context, length); +} + +void StringStorage::append(void *user_context, const char *str, size_t length) { + if (StringUtils::is_empty(str)) { + return; + } + if (length == 0) { + length = strlen(str); + } + const size_t old_length = StringUtils::count_length(data(), contents.size()); + size_t new_length = old_length + length; + reserve(user_context, new_length); + contents.insert(user_context, old_length, str, length); + terminate(user_context, new_length); +} + +void StringStorage::append(void *user_context, char ch) { + const size_t old_length = StringUtils::count_length(data(), contents.size()); + size_t new_length = old_length + 1; + reserve(user_context, new_length); + contents.insert(user_context, old_length, &ch, 1); + terminate(user_context, new_length); +} + +void StringStorage::prepend(void *user_context, const char *str, size_t length) { + if (StringUtils::is_empty(str)) { + return; + } + if (length == 0) { + length = strlen(str); + } + const size_t old_length = StringUtils::count_length(data(), contents.size()); + size_t new_length = old_length + length; + reserve(user_context, new_length); + contents.insert(user_context, 0, str, length); + terminate(user_context, new_length); +} + +void StringStorage::prepend(void *user_context, char ch) { + const size_t old_length = StringUtils::count_length(data(), contents.size()); + size_t new_length = old_length + 1; + reserve(user_context, new_length); + contents.prepend(user_context, &ch); + terminate(user_context, new_length); +} + +void StringStorage::terminate(void *user_context, size_t length) { + if (contents.data() && (length < contents.size())) { + char *end_ptr = static_cast(contents[length]); + (*end_ptr) = '\0'; + } +} + +void StringStorage::clear(void *user_context) { + contents.clear(user_context); + terminate(user_context, 0); +} + +void StringStorage::initialize(void *user_context, uint32_t capacity, const SystemMemoryAllocatorFns &sma) { + contents.initialize(user_context, {sizeof(char), 32, 32}, sma); + reserve(user_context, capacity); + terminate(user_context, 0); +} + +void StringStorage::destroy(void *user_context) { + contents.destroy(user_context); +} + +size_t StringStorage::length() const { + return StringUtils::count_length(data(), contents.size()); +} + +const char *StringStorage::data() const { + return static_cast(contents.data()); +} + +const SystemMemoryAllocatorFns & +StringStorage::current_allocator() const { + return contents.current_allocator(); +} + +const SystemMemoryAllocatorFns & +StringStorage::default_allocator() { + return BlockStorage::default_allocator(); +} + +// -- + +} // namespace Internal +} // namespace Runtime +} // namespace Halide + +#endif // HALIDE_RUNTIME_STRING_STORAGE_H diff --git a/src/runtime/internal/string_table.h b/src/runtime/internal/string_table.h new file mode 100644 index 000000000000..7fa31eccb414 --- /dev/null +++ b/src/runtime/internal/string_table.h @@ -0,0 +1,213 @@ +#ifndef HALIDE_RUNTIME_STRING_TABLE_H +#define HALIDE_RUNTIME_STRING_TABLE_H + +#include "block_storage.h" +#include "pointer_table.h" +#include "string_storage.h" + +namespace Halide { +namespace Runtime { +namespace Internal { + +// Storage class for an array of strings (based on block storage) +// -- Intended for building and maintaining tables of strings +class StringTable { +public: + // Disable copy constructors + StringTable(const StringTable &) = delete; + StringTable &operator=(const StringTable &) = delete; + + StringTable(const SystemMemoryAllocatorFns &allocator = StringStorage::default_allocator()); + StringTable(void *user_context, size_t capacity, const SystemMemoryAllocatorFns &allocator = StringStorage::default_allocator()); + StringTable(void *user_context, const char **array, size_t count, const SystemMemoryAllocatorFns &allocator = StringStorage::default_allocator()); + ~StringTable(); + + void resize(void *user_context, size_t capacity); + void destroy(void *user_context); + void clear(void *user_context); + + // fills the contents of the table (copies strings from given array) + void fill(void *user_context, const char **array, size_t coun); + + // assign the entry at given index the given string + void assign(void *user_context, size_t index, const char *str, size_t length = 0); // if length is zero, strlen is used + + // appends the given string to the end of the table + void append(void *user_context, const char *str, size_t length = 0); // if length is zero, strlen is used + + // prepend the given string to the end of the table + void prepend(void *user_context, const char *str, size_t length = 0); // if length is zero, strlen is used + + // parses the given c-string based on given delimiter, stores each substring in the resulting table + size_t parse(void *user_context, const char *str, const char *delim); + + // index-based access operator + const char *operator[](size_t index) const; + + // returns the raw string table pointer + const char **data() const; + + // scans the table for existance of the given string within any entry (linear scan w/string compare!) + bool contains(const char *str) const; + + size_t size() const { + return contents.size(); + } + +private: + PointerTable contents; //< owns string data + PointerTable pointers; //< pointers to raw string data +}; + +// -- + +StringTable::StringTable(const SystemMemoryAllocatorFns &sma) + : contents(nullptr, 0, sma), + pointers(nullptr, 0, sma) { + // EMPTY! +} + +StringTable::StringTable(void *user_context, size_t capacity, const SystemMemoryAllocatorFns &sma) + : contents(user_context, capacity, sma), + pointers(user_context, capacity, sma) { + if (capacity) { + resize(user_context, capacity); + } +} + +StringTable::StringTable(void *user_context, const char **array, size_t count, const SystemMemoryAllocatorFns &sma) + : contents(user_context, count, sma), + pointers(user_context, count, sma) { + fill(user_context, array, count); +} + +StringTable::~StringTable() { + destroy(nullptr); +} + +void StringTable::resize(void *user_context, size_t capacity) { + pointers.resize(user_context, capacity); + while (contents.size() < capacity) { + StringStorage *storage_ptr = StringStorage::create(user_context, contents.current_allocator()); + contents.append(user_context, storage_ptr); + } +} + +void StringTable::clear(void *user_context) { + for (size_t n = 0; n < contents.size(); ++n) { + StringStorage *storage_ptr = static_cast(contents[n]); + StringStorage::destroy(user_context, storage_ptr); + contents.assign(user_context, n, nullptr); + } + contents.clear(user_context); + pointers.clear(user_context); +} + +void StringTable::destroy(void *user_context) { + for (size_t n = 0; n < contents.size(); ++n) { + StringStorage *storage_ptr = static_cast(contents[n]); + StringStorage::destroy(user_context, storage_ptr); + contents.assign(user_context, n, nullptr); + } + contents.destroy(user_context); + pointers.destroy(user_context); +} + +const char *StringTable::operator[](size_t index) const { + if (index < pointers.size()) { + return static_cast(pointers[index]); + } + return nullptr; +} + +void StringTable::fill(void *user_context, const char **array, size_t count) { + resize(user_context, count); + for (size_t n = 0; n < count && n < contents.size(); ++n) { + StringStorage *storage_ptr = static_cast(contents[n]); + storage_ptr->assign(user_context, array[n]); + pointers.assign(user_context, n, storage_ptr->data()); + } +} + +void StringTable::assign(void *user_context, size_t index, const char *str, size_t length) { + if (length == 0) { + length = strlen(str); + } + if (index < contents.size()) { + StringStorage *storage_ptr = static_cast(contents[index]); + storage_ptr->assign(user_context, str, length); + pointers.assign(user_context, index, storage_ptr->data()); + } +} + +void StringTable::append(void *user_context, const char *str, size_t length) { + StringStorage *storage_ptr = StringStorage::create(user_context, contents.current_allocator()); + storage_ptr->assign(user_context, str, length); + contents.append(user_context, storage_ptr); + pointers.append(user_context, storage_ptr->data()); +} + +void StringTable::prepend(void *user_context, const char *str, size_t length) { + StringStorage *storage_ptr = StringStorage::create(user_context, contents.current_allocator()); + storage_ptr->assign(user_context, str, length); + contents.prepend(user_context, storage_ptr); + pointers.prepend(user_context, storage_ptr->data()); +} + +size_t StringTable::parse(void *user_context, const char *str, const char *delim) { + if (StringUtils::is_empty(str)) { + return 0; + } + + size_t delim_length = strlen(delim); + size_t total_length = strlen(str); + size_t entry_count = StringUtils::count_tokens(str, delim); + if (entry_count < 1) { + return 0; + } + + resize(user_context, entry_count); + + // save each entry into the table + size_t index = 0; + const char *ptr = str; + while (!StringUtils::is_empty(ptr) && (index < entry_count)) { + size_t ptr_offset = ptr - str; + const char *next_delim = strstr(ptr, delim); + size_t token_length = (next_delim == nullptr) ? (total_length - ptr_offset) : (next_delim - ptr); + if (token_length > 0 && index < contents.size()) { + StringStorage *storage_ptr = static_cast(contents[index]); + storage_ptr->assign(user_context, ptr, token_length); + pointers.assign(user_context, index, storage_ptr->data()); + ++index; + } + ptr = (next_delim != nullptr) ? (next_delim + delim_length) : nullptr; + } + return entry_count; +} + +bool StringTable::contains(const char *str) const { + if (StringUtils::is_empty(str)) { + return false; + } + for (size_t n = 0; n < contents.size(); ++n) { + StringStorage *storage_ptr = static_cast(contents[n]); + if (storage_ptr->contains(str)) { + return true; + } + } + + return false; +} + +const char **StringTable::data() const { + return reinterpret_cast(pointers.data()); +} + +// -- + +} // namespace Internal +} // namespace Runtime +} // namespace Halide + +#endif // HALIDE_RUNTIME_STRING_STORAGE_H diff --git a/src/runtime/matlab.cpp b/src/runtime/matlab.cpp deleted file mode 100644 index 959d69d56476..000000000000 --- a/src/runtime/matlab.cpp +++ /dev/null @@ -1,528 +0,0 @@ -#include "HalideRuntime.h" -#include "printer.h" - -#ifndef MX_API_VER -#define MX_API_VER 0x07040000 -#endif - -struct mxArray; - -// It is important to have the mex function pointer definitions in a -// namespace to avoid silently conflicting symbols with matlab at -// runtime. -namespace Halide { -namespace Runtime { -namespace mex { - -// Define a few things from mex.h that we need to grab the mex APIs -// from matlab. - -enum { TMW_NAME_LENGTH_MAX = 64 }; -enum { mxMAXNAM = TMW_NAME_LENGTH_MAX }; - -typedef bool mxLogical; -typedef int16_t mxChar; - -enum mxClassID { - mxUNKNOWN_CLASS = 0, - mxCELL_CLASS, - mxSTRUCT_CLASS, - mxLOGICAL_CLASS, - mxCHAR_CLASS, - mxVOID_CLASS, - mxDOUBLE_CLASS, - mxSINGLE_CLASS, - mxINT8_CLASS, - mxUINT8_CLASS, - mxINT16_CLASS, - mxUINT16_CLASS, - mxINT32_CLASS, - mxUINT32_CLASS, - mxINT64_CLASS, - mxUINT64_CLASS, - mxFUNCTION_CLASS, - mxOPAQUE_CLASS, - mxOBJECT_CLASS, -#ifdef BITS_32 - mxINDEX_CLASS = mxUINT32_CLASS, -#else - mxINDEX_CLASS = mxUINT64_CLASS, -#endif - - mxSPARSE_CLASS = mxVOID_CLASS -}; - -enum mxComplexity { - mxREAL = 0, - mxCOMPLEX -}; - -#ifdef BITS_32 -typedef int mwSize; -typedef int mwIndex; -typedef int mwSignedIndex; -#else -typedef size_t mwSize; -typedef size_t mwIndex; -typedef ptrdiff_t mwSignedIndex; -#endif - -typedef void (*mex_exit_fn)(); - -// Declare function pointers for the mex APIs. -#define MEX_FN(ret, func, args) ret(*func) args; // NOLINT(bugprone-macro-parentheses) -#include "mex_functions.h" -#undef MEX_FN - -// Given a halide type code and bit width, find the equivalent matlab class ID. -WEAK mxClassID get_class_id(int32_t type_code, int32_t type_bits) { - switch (type_code) { - case halide_type_int: - switch (type_bits) { - case 1: - return mxLOGICAL_CLASS; - case 8: - return mxINT8_CLASS; - case 16: - return mxINT16_CLASS; - case 32: - return mxINT32_CLASS; - case 64: - return mxINT64_CLASS; - } - return mxUNKNOWN_CLASS; - case halide_type_uint: - switch (type_bits) { - case 1: - return mxLOGICAL_CLASS; - case 8: - return mxUINT8_CLASS; - case 16: - return mxUINT16_CLASS; - case 32: - return mxUINT32_CLASS; - case 64: - return mxUINT64_CLASS; - } - return mxUNKNOWN_CLASS; - case halide_type_float: - switch (type_bits) { - case 32: - return mxSINGLE_CLASS; - case 64: - return mxDOUBLE_CLASS; - } - return mxUNKNOWN_CLASS; - } - return mxUNKNOWN_CLASS; -} - -// Convert a matlab class ID to a string. -WEAK const char *get_class_name(mxClassID id) { - switch (id) { - case mxCELL_CLASS: - return "cell"; - case mxSTRUCT_CLASS: - return "struct"; - case mxLOGICAL_CLASS: - return "logical"; - case mxCHAR_CLASS: - return "char"; - case mxVOID_CLASS: - return "void"; - case mxDOUBLE_CLASS: - return "double"; - case mxSINGLE_CLASS: - return "single"; - case mxINT8_CLASS: - return "int8"; - case mxUINT8_CLASS: - return "uint8"; - case mxINT16_CLASS: - return "int16"; - case mxUINT16_CLASS: - return "uint16"; - case mxINT32_CLASS: - return "int32"; - case mxUINT32_CLASS: - return "uint32"; - case mxINT64_CLASS: - return "int64"; - case mxUINT64_CLASS: - return "uint64"; - case mxFUNCTION_CLASS: - return "function"; - case mxOPAQUE_CLASS: - return "opaque"; - case mxOBJECT_CLASS: - return "object"; - default: - return "unknown"; - } -} - -// Get the real data pointer from an mxArray. -template -ALWAYS_INLINE T *get_data(mxArray *a) { - return (T *)mxGetData(a); -} -template -ALWAYS_INLINE const T *get_data(const mxArray *a) { - return (const T *)mxGetData(a); -} - -// Search for a symbol in the calling process (i.e. matlab). -template -ALWAYS_INLINE T get_mex_symbol(void *user_context, const char *name, bool required) { - T s = (T)halide_get_symbol(name); - if (required && s == nullptr) { - error(user_context) << "mex API not found: " << name << "\n"; - return nullptr; - } - return s; -} - -// Provide Matlab API version agnostic wrappers for version specific APIs. -ALWAYS_INLINE size_t get_number_of_dimensions(const mxArray *a) { - if (mxGetNumberOfDimensions_730) { - return mxGetNumberOfDimensions_730(a); - } else { - return mxGetNumberOfDimensions_700(a); - } -} - -ALWAYS_INLINE size_t get_dimension(const mxArray *a, size_t n) { - if (mxGetDimensions_730) { - return mxGetDimensions_730(a)[n]; - } else { - return mxGetDimensions_700(a)[n]; - } -} - -ALWAYS_INLINE mxArray *create_numeric_matrix(size_t M, size_t N, mxClassID type, mxComplexity complexity) { - if (mxCreateNumericMatrix_730) { - return mxCreateNumericMatrix_730(M, N, type, complexity); - } else { - return mxCreateNumericMatrix_700(M, N, type, complexity); - } -} - -} // namespace mex -} // namespace Runtime -} // namespace Halide - -using namespace Halide::Runtime::mex; - -extern "C" { - -WEAK void halide_matlab_describe_pipeline(stringstream &desc, const halide_filter_metadata_t *metadata) { - desc << "int " << metadata->name << "("; - for (int i = 0; i < metadata->num_arguments; i++) { - const halide_filter_argument_t *arg = &metadata->arguments[i]; - if (i > 0) { - desc << ", "; - } - if (arg->kind == halide_argument_kind_output_buffer) { - desc << "out "; - } - if (arg->kind == halide_argument_kind_output_buffer || - arg->kind == halide_argument_kind_input_buffer) { - desc << arg->dimensions << "d "; - } else if (arg->kind == halide_argument_kind_input_scalar) { - desc << "scalar "; - } - desc << get_class_name(get_class_id(arg->type.code, arg->type.bits)); - desc << " '" << arg->name << "'"; - } - desc << ")"; -} - -WEAK void halide_matlab_note_pipeline_description(void *user_context, const halide_filter_metadata_t *metadata) { - stringstream desc(user_context); - desc << "Note pipeline definition:\n"; - halide_matlab_describe_pipeline(desc, metadata); - halide_print(user_context, desc.str()); -} - -WEAK void halide_matlab_error(void *user_context, const char *msg) { - // Note that mexErrMsg/mexErrMsgIdAndTxt crash Matlab. It seems to - // be a common problem, those APIs seem to be very fragile. - stringstream error_msg(user_context); - error_msg << "\nHalide Error: " << msg; - mexWarnMsgTxt(error_msg.str()); -} - -WEAK void halide_matlab_print(void *, const char *msg) { - mexWarnMsgTxt(msg); -} - -WEAK int halide_matlab_init(void *user_context) { - // Assume that if mexWarnMsgTxt exists, we've already attempted initialization. - if (mexWarnMsgTxt != nullptr) { - return halide_error_code_success; - } - -// clang-format off -#define MEX_FN(ret, func, args) func = get_mex_symbol(user_context, #func, true); // NOLINT(bugprone-macro-parentheses) -#define MEX_FN_700(ret, func, func_700, args) func_700 = get_mex_symbol(user_context, #func, false); // NOLINT(bugprone-macro-parentheses) -#define MEX_FN_730(ret, func, func_730, args) func_730 = get_mex_symbol(user_context, #func_730, false); // NOLINT(bugprone-macro-parentheses) -#include "mex_functions.h" -#undef MEX_FN_730 -#undef MEX_FN_700 -#undef MEX_FN - // clang-format on - - if (!mexWarnMsgTxt) { - return halide_error_code_matlab_init_failed; - } - - // Set up Halide's printing to go through Matlab. Also, don't exit - // on error. We don't just replace halide_error/halide_printf, - // because they'd have to be weak here, and there would be no - // guarantee that we would get this version (and not the standard - // one). - halide_set_custom_print(halide_matlab_print); - halide_set_error_handler(halide_matlab_error); - - return halide_error_code_success; -} - -// Convert a matlab mxArray to a Halide halide_buffer_t, with a specific number of dimensions. -WEAK int halide_matlab_array_to_halide_buffer_t(void *user_context, - const mxArray *arr, - const halide_filter_argument_t *arg, - halide_buffer_t *buf) { - - if (mxIsComplex(arr)) { - error(user_context) << "Complex argument not supported for parameter " << arg->name << ".\n"; - return halide_error_code_matlab_bad_param_type; - } - - int dim_count = get_number_of_dimensions(arr); - int expected_dims = arg->dimensions; - - // Validate that the data type of a buffer matches exactly. - mxClassID arg_class_id = get_class_id(arg->type.code, arg->type.bits); - mxClassID class_id = mxGetClassID(arr); - if (class_id != arg_class_id) { - error(user_context) << "Expected type of class " << get_class_name(arg_class_id) - << " for argument " << arg->name - << ", got class " << get_class_name(class_id) << ".\n"; - return halide_error_code_matlab_bad_param_type; - } - // Validate that the dimensionality matches. Matlab is wierd - // because matrices always have at least 2 dimensions, and it - // truncates trailing dimensions of extent 1. So, the only way - // to have an error here is to have more dimensions with - // extent != 1 than the Halide pipeline expects. - while (dim_count > 0 && get_dimension(arr, dim_count - 1) == 1) { - dim_count--; - } - if (dim_count > expected_dims) { - error(user_context) << "Expected array of rank " << expected_dims - << " for argument " << arg->name - << ", got array of rank " << dim_count << ".\n"; - return halide_error_code_matlab_bad_param_type; - } - - buf->host = (uint8_t *)mxGetData(arr); - buf->type = arg->type; - buf->dimensions = arg->dimensions; - buf->set_host_dirty(true); - - for (int i = 0; i < dim_count && i < expected_dims; i++) { - buf->dim[i].extent = static_cast(get_dimension(arr, i)); - } - - // Add back the dimensions with extent 1. - for (int i = 2; i < expected_dims; i++) { - if (buf->dim[i].extent == 0) { - buf->dim[i].extent = 1; - } - } - - // Compute dense strides. - buf->dim[0].stride = 1; - for (int i = 1; i < expected_dims; i++) { - buf->dim[i].stride = buf->dim[i - 1].extent * buf->dim[i - 1].stride; - } - - return halide_error_code_success; -} - -// Convert a matlab mxArray to a scalar. -WEAK int halide_matlab_array_to_scalar(void *user_context, - const mxArray *arr, const halide_filter_argument_t *arg, void *scalar) { - if (mxIsComplex(arr)) { - error(user_context) << "Complex argument not supported for parameter " << arg->name << ".\n"; - return halide_error_code_generic_error; - } - - // Validate that the mxArray has all dimensions of extent 1. - int dim_count = get_number_of_dimensions(arr); - for (int i = 0; i < dim_count; i++) { - if (get_dimension(arr, i) != 1) { - error(user_context) << "Expected scalar argument for parameter " << arg->name << ".\n"; - return halide_error_code_matlab_bad_param_type; - } - } - if (!mxIsLogical(arr) && !mxIsNumeric(arr)) { - error(user_context) << "Expected numeric argument for scalar parameter " << arg->name - << ", got " << get_class_name(mxGetClassID(arr)) << ".\n"; - return halide_error_code_matlab_bad_param_type; - } - - double value = mxGetScalar(arr); - int32_t type_code = arg->type.code; - int32_t type_bits = arg->type.bits; - - if (type_code == halide_type_int) { - switch (type_bits) { - case 1: - *reinterpret_cast(scalar) = value != 0; - return halide_error_code_success; - case 8: - *reinterpret_cast(scalar) = static_cast(value); - return halide_error_code_success; - case 16: - *reinterpret_cast(scalar) = static_cast(value); - return halide_error_code_success; - case 32: - *reinterpret_cast(scalar) = static_cast(value); - return halide_error_code_success; - case 64: - *reinterpret_cast(scalar) = static_cast(value); - return halide_error_code_success; - } - } else if (type_code == halide_type_uint) { - switch (type_bits) { - case 1: - *reinterpret_cast(scalar) = value != 0; - return halide_error_code_success; - case 8: - *reinterpret_cast(scalar) = static_cast(value); - return halide_error_code_success; - case 16: - *reinterpret_cast(scalar) = static_cast(value); - return halide_error_code_success; - case 32: - *reinterpret_cast(scalar) = static_cast(value); - return halide_error_code_success; - case 64: - *reinterpret_cast(scalar) = static_cast(value); - return halide_error_code_success; - } - } else if (type_code == halide_type_float) { - switch (type_bits) { - case 32: - *reinterpret_cast(scalar) = static_cast(value); - return halide_error_code_success; - case 64: - *reinterpret_cast(scalar) = static_cast(value); - return halide_error_code_success; - } - } else if (type_code == halide_type_handle) { - error(user_context) << "Parameter " << arg->name << " is of a type not supported by Matlab.\n"; - return halide_error_code_matlab_bad_param_type; - } - error(user_context) << "Halide metadata for " << arg->name << " contained invalid or unrecognized type description.\n"; - return halide_error_code_internal_error; -} - -WEAK int halide_matlab_call_pipeline(void *user_context, - int (*pipeline)(void **args), const halide_filter_metadata_t *metadata, - int nlhs, mxArray **plhs, int nrhs, const mxArray **prhs) { - - int init_result = halide_matlab_init(user_context); - if (init_result != 0) { - return init_result; - } - - int32_t result_storage; - int32_t *result_ptr = &result_storage; - if (nlhs > 0) { - plhs[0] = create_numeric_matrix(1, 1, mxINT32_CLASS, mxREAL); - result_ptr = get_data(plhs[0]); - } - int32_t &result = *result_ptr; - - // Set result to failure until proven otherwise. - result = halide_error_code_generic_error; - - // Validate the number of arguments is correct. - if (nrhs != metadata->num_arguments) { - if (nrhs > 0) { - // Only report an actual error if there were any arguments at all. - error(user_context) << "Expected " << metadata->num_arguments - << " arguments for Halide pipeline " << metadata->name - << ", got " << nrhs << ".\n"; - } - halide_matlab_note_pipeline_description(user_context, metadata); - return result; - } - - // Validate the LHS has zero or one argument. - if (nlhs > 1) { - error(user_context) << "Expected zero or one return value for Halide pipeline " << metadata->name - << ", got " << nlhs << ".\n"; - halide_matlab_note_pipeline_description(user_context, metadata); - return result; - } - - void **args = (void **)__builtin_alloca(nrhs * sizeof(void *)); - for (int i = 0; i < nrhs; i++) { - const mxArray *arg = prhs[i]; - const halide_filter_argument_t *arg_metadata = &metadata->arguments[i]; - - if (arg_metadata->kind == halide_argument_kind_input_buffer || - arg_metadata->kind == halide_argument_kind_output_buffer) { - halide_buffer_t *buf = (halide_buffer_t *)__builtin_alloca(sizeof(halide_buffer_t)); - memset(buf, 0, sizeof(halide_buffer_t)); - buf->dim = (halide_dimension_t *)__builtin_alloca(sizeof(halide_dimension_t) * arg_metadata->dimensions); - memset(buf->dim, 0, sizeof(halide_dimension_t) * arg_metadata->dimensions); - result = halide_matlab_array_to_halide_buffer_t(user_context, arg, arg_metadata, buf); - if (result != 0) { - halide_matlab_note_pipeline_description(user_context, metadata); - return result; - } - args[i] = buf; - } else { - size_t size_bytes = max(8, (arg_metadata->type.bits + 7) / 8); - void *scalar = __builtin_alloca(size_bytes); - memset(scalar, 0, size_bytes); - result = halide_matlab_array_to_scalar(user_context, arg, arg_metadata, scalar); - if (result != 0) { - halide_matlab_note_pipeline_description(user_context, metadata); - return result; - } - args[i] = scalar; - } - } - - result = pipeline(args); - - // Copy any GPU resident output buffers back to the CPU before returning. - for (int i = 0; i < nrhs; i++) { - const halide_filter_argument_t *arg_metadata = &metadata->arguments[i]; - - if (arg_metadata->kind == halide_argument_kind_output_buffer) { - halide_buffer_t *buf = (halide_buffer_t *)args[i]; - if ((result = halide_copy_to_host(user_context, buf)) != 0) { - error(user_context) << "halide_matlab_call_pipeline: halide_copy_to_host failed.\n"; - return result; - } - } - if (arg_metadata->kind == halide_argument_kind_input_buffer || - arg_metadata->kind == halide_argument_kind_output_buffer) { - halide_buffer_t *buf = (halide_buffer_t *)args[i]; - if ((result = halide_device_free(user_context, buf)) != 0) { - error(user_context) << "halide_matlab_call_pipeline: halide_device_free failed.\n"; - return result; - } - } - } - - return result; -} - -} // extern "C" diff --git a/src/runtime/metadata.cpp b/src/runtime/metadata.cpp deleted file mode 100644 index b90ba174a146..000000000000 --- a/src/runtime/metadata.cpp +++ /dev/null @@ -1,16 +0,0 @@ -#include "HalideRuntime.h" - -namespace Halide { -namespace Runtime { -namespace Internal { - -// This is unused and expected to be optimized away; it exists solely to ensure -// that the halide_filter_metadata_t type is in the runtime module, so that -// Codegen_LLVM can access its description. -WEAK const halide_filter_metadata_t *unused_function_to_get_halide_filter_metadata_t_declared() { - return nullptr; -} - -} // namespace Internal -} // namespace Runtime -} // namespace Halide diff --git a/src/runtime/mex_functions.h b/src/runtime/mex_functions.h deleted file mode 100644 index e1ca89399361..000000000000 --- a/src/runtime/mex_functions.h +++ /dev/null @@ -1,203 +0,0 @@ -// This file intentionally does not use include guards!! -// The intended usage of this file is to define MEX_FN to do something -// useful with a mex function, and then include this file. This file #undefs -// MEX_FN after it is done. This file contains 3 types of functions: -// -// - MEX_FN(ret, func, args): A function with return type 'ret', name 'func', -// and arguments 'args'. -// - MEX_FN_700(ret, func, func_700, args): Similar to MEX_FN, but func_700 is -// the name of the function with _700 appended. This is only used for -// Matlab 7.0 API functions. -// - MEX_FN_730(ret, func, func_730, args): Similar to MEX_FN_700, but for the -// Matlab 7.3 API. - -// Provide default no-op definitions for the 3 macros if they don't already -// exist. -#ifndef MEX_FN -#define MEX_FN(ret, func, args) -#endif - -#ifndef MEX_FN_730 -#define MEX_FN_730(ret, func, func_730, args) MEX_FN(ret, func_730, args) -#endif - -#ifndef MEX_FN_700 -#define MEX_FN_700(ret, func, func_700, args) MEX_FN(ret, func_700, args) -#endif - -// mex.h -//MEX_FN(int, mexPrintf, (const char*, ...)); -//MEX_FN(void, mexErrMsgTxt, (const char*)); -//MEX_FN(void, mexErrMsgIdAndTxt, (const char *, const char*, ...)); -MEX_FN(void, mexWarnMsgTxt, (const char *)); -//MEX_FN(void, mexWarnMsgIdAndTxt, (const char *, const char*, ...)); -//MEX_FN(void, mexMakeArrayPersistent, (const mxArray*)); -//MEX_FN(void, mexMakeMemoryPersistent, (void *ptr)); -//MEX_FN(int, mexSet, (double, const char*, mxArray*)); -//MEX_FN(const mxArray*, mexGet, (double, const char*)); -//MEX_FN(int, mexCallMATLAB, (int, mxArray**, int, const mxArray**, const char *)); -//MEX_FN(mxArray*, mexCallMATLABWithTrap, (int, mxArray**, int, const mxArray**, const char *)); -//MEX_FN(void, mexSetTrapFlag, (int)); -//MEX_FN(void, mexPrintAssertion, (const char*, const char*, int, const char*)); -//MEX_FN(bool, mexIsGlobal, (const mxArray*)); -//MEX_FN(int, mexPutVariable, (const char*, const char*, const mxArray*)); -//MEX_FN(const mxArray*, mexGetVariablePtr, (const char*, const char*)); -//MEX_FN(mxArray*, mexGetVariable, (const char*, const char*)); -//MEX_FN(void, mexLock, (void)); -//MEX_FN(void, mexUnlock, (void)); -//MEX_FN(bool, mexIsLocked, (void)); -//MEX_FN(const char*, mexFunctionName, (void)); -//MEX_FN(int, mexEvalString, (const char*)); -//MEX_FN(mxArray*, mexEvalStringWithTrap, (const char*)); -//MEX_FN(int, mexAtExit, (mex_exit_fn)); - -// matrix.h -//MEX_FN(void*, mxMalloc, (size_t)); -//MEX_FN(void*, mxCalloc, (size_t, size_t)); -//MEX_FN(void, mxFree, (void*)); -//MEX_FN(void*, mxRealloc, (void*, size_t)); -MEX_FN_730(size_t, mxGetNumberOfDimensions, mxGetNumberOfDimensions_730, (const mxArray *)); -MEX_FN_700(int, mxGetNumberOfDimensions, mxGetNumberOfDimensions_700, (const mxArray *)); -MEX_FN_730(const size_t *, mxGetDimensions, mxGetDimensions_730, (const mxArray *)); -MEX_FN_700(const int *, mxGetDimensions, mxGetDimensions_700, (const mxArray *)); -//MEX_FN(size_t, mxGetM, (const mxArray*)); -//MEX_FN_730(size_t*, mxGetIr, mxGetIr_730, (const mxArray*)); -//MEX_FN_700(int*, mxGetIr, mxGetIr_700, (const mxArray*)); -//MEX_FN_730(size_t*, mxGetJc, mxGetJc_730, (const mxArray*)); -//MEX_FN_700(int*, mxGetJc, mxGetJc_700, (const mxArray*)); -//MEX_FN_730(size_t, mxGetNzmax, mxGetNzmax_730, (const mxArray*)); -//MEX_FN_700(int, mxGetNzmax, mxGetNzmax_700, (const mxArray*)); -//MEX_FN_730(void, mxSetNzmax, mxSetNzmax_730, (mxArray*, size_t)); -//MEX_FN_700(void, mxSetNzmax, mxSetNzmax_700, (mxArray*, int)); -//MEX_FN(const char*, mxGetFieldNameByNumber, (const mxArray*, int)); -//MEX_FN_730(mxArray*, mxGetFieldByNumber, mxGetFieldByNumber_730, (const mxArray*, size_t, int)); -//MEX_FN_700(mxArray*, mxGetFieldByNumber, mxGetFieldByNumber_700, (const mxArray*, int, int)); -//MEX_FN_730(mxArray*, mxGetCell, mxGetCell_730, (const mxArray*, size_t)); -//MEX_FN_700(mxArray*, mxGetCell, mxGetCell_700, (const mxArray*, int)); -MEX_FN(mxClassID, mxGetClassID, (const mxArray *)); -MEX_FN(void *, mxGetData, (const mxArray *)); -//MEX_FN(void, mxSetData, (mxArray*,void*)); -MEX_FN(bool, mxIsNumeric, (const mxArray *)); -//MEX_FN(bool, mxIsCell, (const mxArray*)); -MEX_FN(bool, mxIsLogical, (const mxArray *)); -//MEX_FN(bool, mxIsChar, (const mxArray*)); -//MEX_FN(bool, mxIsStruct, (const mxArray*)); -//MEX_FN(bool, mxIsOpaque, (const mxArray*)); -//MEX_FN(bool, mxIsFunctionHandle, (const mxArray*)); -//MEX_FN(bool, mxIsObject, (const mxArray*)); -//MEX_FN(void*, mxGetImagData, (const mxArray*)); -//MEX_FN(void, mxSetImagData, (mxArray*, void*)); -MEX_FN(bool, mxIsComplex, (const mxArray *)); -//MEX_FN(bool, mxIsSparse, (const mxArray*)); -//MEX_FN(bool, mxIsDouble, (const mxArray*)); -//MEX_FN(bool, mxIsSingle, (const mxArray*)); -//MEX_FN(bool, mxIsInt8, (const mxArray*)); -//MEX_FN(bool, mxIsUint8, (const mxArray*)); -//MEX_FN(bool, mxIsInt16, (const mxArray*)); -//MEX_FN(bool, mxIsUint16, (const mxArray*)); -//MEX_FN(bool, mxIsInt32, (const mxArray*)); -//MEX_FN(bool, mxIsUint32, (const mxArray*)); -//MEX_FN(bool, mxIsInt64, (const mxArray*)); -//MEX_FN(bool, mxIsUint64, (const mxArray*)); -//MEX_FN(size_t, mxGetNumberOfElements, (const mxArray*)); -//MEX_FN(double*, mxGetPr, (const mxArray*)); -//MEX_FN(void, mxSetPr, (mxArray*, double*)); -//MEX_FN(double*, mxGetPi, (const mxArray*)); -//MEX_FN(void, mxSetPi, (mxArray*, double*)); -//MEX_FN(mxChar*, mxGetChars, (const mxArray*)); -//MEX_FN(int, mxGetUserBits, (const mxArray*)); -//MEX_FN(void, mxSetUserBits, (mxArray*, int)); -MEX_FN(double, mxGetScalar, (const mxArray *)); -//MEX_FN(bool, mxIsFromGlobalWS, (const mxArray*)); -//MEX_FN(void, mxSetFromGlobalWS, (mxArray*, bool)); -//MEX_FN_730(void, mxSetM, mxSetM_730, (mxArray*, size_t)); -//MEX_FN_700(void, mxSetM, mxSetM_700, (mxArray*, int)); -//MEX_FN(size_t, mxGetN, (const mxArray*)); -//MEX_FN(bool, mxIsEmpty, (const mxArray*)); -//MEX_FN(int, mxGetFieldNumber, (const mxArray*, const char*)); -//MEX_FN_730(void, mxSetIr, mxSetIr_730, (mxArray*, size_t*)); -//MEX_FN_700(void, mxSetIr, mxSetIr_700, (mxArray*, int*)); -//MEX_FN_730(void, mxSetJc, mxSetJc_730, (mxArray*, size_t*)); -//MEX_FN_700(void, mxSetJc, mxSetJc_700, (mxArray*, int*)); -MEX_FN(size_t, mxGetElementSize, (const mxArray *)); -//MEX_FN_730(size_t, mxCalcSingleSubscript, mxCalcSingleSubscript_730, (const mxArray*, size_t, const size_t*)); -//MEX_FN_700(int, mxCalcSingleSubscript, mxCalcSingleSubscript_700, (const mxArray*, int, const int*)); -//MEX_FN(int, mxGetNumberOfFields, (const mxArray*)); -//MEX_FN_730(void, mxSetCell, mxSetCell_730, (mxArray*, size_t, mxArray*)); -//MEX_FN_700(void, mxSetCell, mxSetCell_700, (mxArray*, int, mxArray*)); -//MEX_FN_730(void, mxSetFieldByNumber, mxSetFieldByNumber_730, (mxArray*, size_t, int, mxArray*)); -//MEX_FN_700(void, mxSetFieldByNumber, mxSetFieldByNumber_700, (mxArray*, int, int, mxArray*)); -//MEX_FN_730(mxArray*, mxGetField, mxGetField_730, (const mxArray*, size_t, const char*)); -//MEX_FN_700(mxArray*, mxGetField, mxGetField_700, (const mxArray*, int, const char*)); -//MEX_FN_730(void, mxSetField, mxSetField_730, (mxArray*, size_t, const char*, mxArray*)); -//MEX_FN_700(void, mxSetField, mxSetField_700, (mxArray*, int, const char*, mxArray*)); -//MEX_FN_730(mxArray*, mxGetProperty, mxGetProperty_730, (const mxArray*, const size_t, const char*)); -//MEX_FN_700(mxArray*, mxGetProperty, mxGetProperty_700, (const mxArray*, const int, const char*)); -//MEX_FN_730(void, mxSetProperty, mxSetProperty_730, (mxArray*, size_t, const char*, const mxArray*)); -//MEX_FN_700(void, mxSetProperty, mxSetProperty_700, (mxArray*, int, const char*, const mxArray*)); -//MEX_FN(const char*, mxGetClassName, (const mxArray*)); -//MEX_FN(bool, mxIsClass, (const mxArray*, const char*)); -MEX_FN_730(mxArray *, mxCreateNumericMatrix, mxCreateNumericMatrix_730, (size_t, size_t, mxClassID, mxComplexity)); -MEX_FN_700(mxArray *, mxCreateNumericMatrix, mxCreateNumericMatrix_700, (int, int, mxClassID, mxComplexity)); -//MEX_FN_730(void, mxSetN, mxSetN_730, (mxArray*, size_t)); -//MEX_FN_700(void, mxSetN, mxSetN_700, (mxArray*, int)); -//MEX_FN_730(int, mxSetDimensions, mxSetDimensions_730, (mxArray*, const size_t*, size_t)); -//MEX_FN_700(int, mxSetDimensions, mxSetDimensions_700, (mxArray*, const int*, int)); -//MEX_FN(void, mxDestroyArray, (mxArray*)); -//MEX_FN_730(mxArray*, mxCreateNumericArray, mxCreateNumericArray_730, (size_t, const size_t*, mxClassID, mxComplexity)); -//MEX_FN_700(mxArray*, mxCreateNumericArray, mxCreateNumericArray_700, (int, const int*, mxClassID, mxComplexity)); -//MEX_FN_730(mxArray*, mxCreateCharArray, mxCreateCharArray_730, (size_t, const size_t*)); -//MEX_FN_700(mxArray*, mxCreateCharArray, mxCreateCharArray_700, (int, const int*)); -//MEX_FN_730(mxArray*, mxCreateDoubleMatrix, mxCreateDoubleMatrix_730, (size_t, size_t, mxComplexity)); -//MEX_FN_700(mxArray*, mxCreateDoubleMatrix, mxCreateDoubleMatrix_700, (int, int, mxComplexity)); -//MEX_FN(mxLogical*, mxGetLogicals, (const mxArray*)); -//MEX_FN_730(mxArray*, mxCreateLogicalArray, mxCreateLogicalArray_730, (size_t, const size_t*)); -//MEX_FN_700(mxArray*, mxCreateLogicalArray, mxCreateLogicalArray_700, (int, const int*)); -//MEX_FN_730(mxArray*, mxCreateLogicalMatrix, mxCreateLogicalMatrix_730, (size_t, size_t)); -//MEX_FN_700(mxArray*, mxCreateLogicalMatrix, mxCreateLogicalMatrix_700, (int, int)); -//MEX_FN(mxArray*, mxCreateLogicalScalar, (bool)); -//MEX_FN(bool, mxIsLogicalScalar, (const mxArray*)); -//MEX_FN(bool, mxIsLogicalScalarTrue, (const mxArray*)); -//MEX_FN(mxArray*, mxCreateDoubleScalar, (double)); -//MEX_FN_730(mxArray*, mxCreateSparse, mxCreateSparse_730, (size_t, size_t, size_t, mxComplexity)); -//MEX_FN_700(mxArray*, mxCreateSparse, mxCreateSparse_700, (int, int, int, mxComplexity)); -//MEX_FN_730(mxArray*, mxCreateSparseLogicalMatrix, mxCreateSparseLogicalMatrix_730, (size_t, size_t, size_t)); -//MEX_FN_700(mxArray*, mxCreateSparseLogicalMatrix, mxCreateSparseLogicalMatrix_700, (int, int, int)); -//MEX_FN_730(void, mxGetNChars, mxGetNChars_730, (const mxArray*, char*, size_t)); -//MEX_FN_700(void, mxGetNChars, mxGetNChars_700, (const mxArray*, char*, int)); -//MEX_FN_730(int, mxGetString, mxGetString_730, (const mxArray*, char*, size_t)); -//MEX_FN_700(int, mxGetString, mxGetString_700, (const mxArray*, char*, int)); -//MEX_FN(char*, mxArrayToString, (const mxArray*)); -//MEX_FN_730(mxArray*, mxCreateStringFromNChars, mxCreateStringFromNChars_730, (const char*, size_t)); -//MEX_FN_700(mxArray*, mxCreateStringFromNChars, mxCreateStringFromNChars_700, (const char*, int)); -//MEX_FN(mxArray*, mxCreateString, (const char*)); -//MEX_FN_730(mxArray*, mxCreateCharMatrixFromStrings, mxCreateCharMatrixFromStrings_730, (size_t, const char**)); -//MEX_FN_700(mxArray*, mxCreateCharMatrixFromStrings, mxCreateCharMatrixFromStrings_700, (int, const char**)); -//MEX_FN_730(mxArray*, mxCreateCellMatrix, mxCreateCellMatrix_730, (size_t, size_t)); -//MEX_FN_700(mxArray*, mxCreateCellMatrix, mxCreateCellMatrix_700, (int, int)); -//MEX_FN_730(mxArray*, mxCreateCellArray, mxCreateCellArray_730, (size_t, const size_t*)); -//MEX_FN_700(mxArray*, mxCreateCellArray, mxCreateCellArray_700, (int, const int*)); -//MEX_FN_730(mxArray*, mxCreateStructMatrix, mxCreateStructMatrix_730, (size_t, size_t, int, const char**)); -//MEX_FN_700(mxArray*, mxCreateStructMatrix, mxCreateStructMatrix_700, (int, int, int, const char**)); -//MEX_FN_730(mxArray*, mxCreateStructArray, mxCreateStructArray_730, (size_t, const size_t*, int, const char**)); -//MEX_FN_700(mxArray*, mxCreateStructArray, mxCreateStructArray_700, (int, const int*, int, const char**)); -//MEX_FN(mxArray*, mxDuplicateArray, (const mxArray*)); -//MEX_FN(int, mxSetClassName, (mxArray*, const char*)); -//MEX_FN(int, mxAddField, (mxArray*, const char*)); -//MEX_FN(void, mxRemoveField, (mxArray*, int)); -//MEX_FN(double, mxGetEps, (void)); -//MEX_FN(double, mxGetInf, (void)); -//MEX_FN(double, mxGetNaN, (void)); -//MEX_FN(bool, mxIsFinite, (double)); -//MEX_FN(bool, mxIsInf, (double)); -//MEX_FN(bool, mxIsNaN, (double)); - -#ifdef MEX_FN -#undef MEX_FN -#endif -#ifdef MEX_FN_730 -#undef MEX_FN_730 -#endif -#ifdef MEX_FN_700 -#undef MEX_FN_700 -#endif diff --git a/src/runtime/mini_d3d12.h b/src/runtime/mini_d3d12.h index d003b3514690..e829b997bbc3 100644 --- a/src/runtime/mini_d3d12.h +++ b/src/runtime/mini_d3d12.h @@ -715,8 +715,7 @@ _Post_equal_to_(pp) _Post_satisfies_(return == pp) void **IID_PPV_ARGS_Helper(T #define DECLARE_INTERFACE(iface) \ typedef interface iface { \ const struct iface##Vtbl FAR *lpVtbl; \ - } \ - iface; \ + } iface; \ typedef const struct iface##Vtbl iface##Vtbl; \ const struct iface##Vtbl #else @@ -725,8 +724,7 @@ _Post_equal_to_(pp) _Post_satisfies_(return == pp) void **IID_PPV_ARGS_Helper(T #define DECLARE_INTERFACE(iface) \ typedef interface iface { \ struct iface##Vtbl FAR *lpVtbl; \ - } \ - iface; \ + } iface; \ typedef struct iface##Vtbl iface##Vtbl; \ struct iface##Vtbl #endif @@ -6755,7 +6753,7 @@ interface IDXGIAdapter1 { #endif /* __IDXGIAdapter1_INTERFACE_DEFINED__ */ // NOTE(marcos): declaring CreateDXGIFactory "1" since it works on UWP as well -//HRESULT WINAPI CreateDXGIFactory1(REFIID riid, _COM_Outptr_ void **ppFactory); +// HRESULT WINAPI CreateDXGIFactory1(REFIID riid, _COM_Outptr_ void **ppFactory); typedef HRESULT(WINAPI *PFN_CREATEDXGIFACORY1)(REFIID riid, _COM_Outptr_ void **ppFactory); DEFINE_GUID(IID_IDXGIObject, 0xaec22fb8, 0x76f3, 0x4639, 0x9b, 0xe0, 0x28, 0xeb, 0x43, 0xa6, 0x7a, 0x2e); diff --git a/src/runtime/mini_hexagon_dma.h b/src/runtime/mini_hexagon_dma.h index 1aad24dbb4ec..55f1dea66974 100644 --- a/src/runtime/mini_hexagon_dma.h +++ b/src/runtime/mini_hexagon_dma.h @@ -134,7 +134,7 @@ typedef struct stDmaWrapper_DmaTransferSetup { void *pTcmDataBuf; /// Virtual address of the DDR Frame buffer . void *pFrameBuf; - //UBWC Format + // UBWC Format uint16 bIsFmtUbwc; /// Should the intermediate buffer be padded. This only apply for 8bit format sucha NV12, NV12-4R uint16 bUse16BitPaddingInL2; diff --git a/src/runtime/powerpc.ll b/src/runtime/powerpc.ll index 521bfe9a32b1..99c83716d6f4 100644 --- a/src/runtime/powerpc.ll +++ b/src/runtime/powerpc.ll @@ -2,7 +2,7 @@ declare <4 x float> @llvm.ppc.altivec.vrefp(<4 x float>) nounwind readnone declare <4 x float> @llvm.ppc.altivec.vrsqrtefp(<4 x float>) nounwind readnone define weak_odr float @fast_inverse_f32(float %x) readnone alwaysinline { - %vec = insertelement <4 x float> undef, float %x, i32 0 + %vec = insertelement <4 x float> poison, float %x, i32 0 %approx = tail call <4 x float> @llvm.ppc.altivec.vrefp(<4 x float> %vec) %result = extractelement <4 x float> %approx, i32 0 ret float %result @@ -14,7 +14,7 @@ define weak_odr <4 x float> @fast_inverse_f32x4(<4 x float> %x) readnone alwaysi } define weak_odr float @fast_inverse_sqrt_f32(float %x) readnone alwaysinline { - %vec = insertelement <4 x float> undef, float %x, i32 0 + %vec = insertelement <4 x float> poison, float %x, i32 0 %approx = tail call <4 x float> @llvm.ppc.altivec.vrsqrtefp(<4 x float> %vec) %result = extractelement <4 x float> %approx, i32 0 ret float %result diff --git a/src/runtime/runtime_api.cpp b/src/runtime/runtime_api.cpp index 3d36bd7b67e7..c06141dcf267 100644 --- a/src/runtime/runtime_api.cpp +++ b/src/runtime/runtime_api.cpp @@ -18,6 +18,8 @@ // Can be generated via the following: // cat src/runtime/runtime_internal.h src/runtime/HalideRuntime*.h | grep "^[^ ][^(]*halide_[^ ]*(" | grep -v '#define' | sed "s/[^(]*halide/halide/" | sed "s/(.*//" | sed "s/^h/ \(void *)\&h/" | sed "s/$/,/" | sort | uniq +extern "C" void halide_unused_force_include_types(); + extern "C" __attribute__((used)) void *halide_runtime_api_functions[] = { (void *)&halide_buffer_copy, (void *)&halide_buffer_to_string, @@ -111,7 +113,6 @@ extern "C" __attribute__((used)) void *halide_runtime_api_functions[] = { (void *)&halide_join_thread, (void *)&halide_load_library, (void *)&halide_malloc, - (void *)&halide_matlab_call_pipeline, (void *)&halide_memoization_cache_cleanup, (void *)&halide_memoization_cache_evict, (void *)&halide_memoization_cache_lookup, @@ -210,4 +211,5 @@ extern "C" __attribute__((used)) void *halide_runtime_api_functions[] = { (void *)&halide_d3d12compute_finalize_kernels, (void *)&halide_d3d12compute_release_context, (void *)&halide_d3d12compute_run, + (void *)&halide_unused_force_include_types, }; diff --git a/src/runtime/runtime_internal.h b/src/runtime/runtime_internal.h index 9dd3572a2bb9..2801f9bfedc5 100644 --- a/src/runtime/runtime_internal.h +++ b/src/runtime/runtime_internal.h @@ -1,9 +1,13 @@ #ifndef HALIDE_RUNTIME_INTERNAL_H #define HALIDE_RUNTIME_INTERNAL_H +#ifdef COMPILING_HALIDE_RUNTIME_TESTS +// Only allowed if building Halide runtime tests ... since they use system compiler which may be GCC or MSVS +#else #if __STDC_HOSTED__ #error "Halide runtime files must be compiled with clang in freestanding mode." #endif +#endif #ifdef __UINT8_TYPE__ typedef __INT64_TYPE__ int64_t; @@ -92,6 +96,7 @@ int strncmp(const char *s, const char *t, size_t n); size_t strlen(const char *s); const char *strchr(const char *s, int c); void *memcpy(void *s1, const void *s2, size_t n); +void *memmove(void *dest, const void *src, size_t n); int memcmp(const void *s1, const void *s2, size_t n); void *memset(void *s, int val, size_t n); // Use fopen+fileno+fclose instead of open+close - the value of the @@ -164,11 +169,6 @@ WEAK int halide_device_and_host_free(void *user_context, struct halide_buffer_t struct halide_filter_metadata_t; -struct mxArray; -WEAK int halide_matlab_call_pipeline(void *user_context, - int (*pipeline)(void **args), const halide_filter_metadata_t *metadata, - int nlhs, mxArray **plhs, int nrhs, const mxArray **prhs); - WEAK int halide_trace_helper(void *user_context, const char *func, void *value, int *coords, diff --git a/src/runtime/wasm_math.ll b/src/runtime/wasm_math.ll index aaa4d0778057..f39705e12799 100644 --- a/src/runtime/wasm_math.ll +++ b/src/runtime/wasm_math.ll @@ -45,8 +45,8 @@ define weak_odr <4 x float> @fast_inverse_sqrt_f32x4(<4 x float> %x) nounwind al ; i8 -> i16 define weak_odr <8 x i16> @extmul_low_s_v8i16(<16 x i8> %v1, <16 x i8> %v2) nounwind alwaysinline { - %low1 = shufflevector <16 x i8> %v1, <16 x i8> undef, <8 x i32> - %low2 = shufflevector <16 x i8> %v2, <16 x i8> undef, <8 x i32> + %low1 = shufflevector <16 x i8> %v1, <16 x i8> poison, <8 x i32> + %low2 = shufflevector <16 x i8> %v2, <16 x i8> poison, <8 x i32> %extended1 = sext <8 x i8> %low1 to <8 x i16> %extended2 = sext <8 x i8> %low2 to <8 x i16> %a = mul <8 x i16> %extended1, %extended2 @@ -54,8 +54,8 @@ define weak_odr <8 x i16> @extmul_low_s_v8i16(<16 x i8> %v1, <16 x i8> %v2) noun } define weak_odr <8 x i16> @extmul_high_s_v8i16(<16 x i8> %v1, <16 x i8> %v2) nounwind alwaysinline { - %high1 = shufflevector <16 x i8> %v1, <16 x i8> undef, <8 x i32> - %high2 = shufflevector <16 x i8> %v2, <16 x i8> undef, <8 x i32> + %high1 = shufflevector <16 x i8> %v1, <16 x i8> poison, <8 x i32> + %high2 = shufflevector <16 x i8> %v2, <16 x i8> poison, <8 x i32> %extended1 = sext <8 x i8> %high1 to <8 x i16> %extended2 = sext <8 x i8> %high2 to <8 x i16> %a = mul <8 x i16> %extended1, %extended2 @@ -71,8 +71,8 @@ define weak_odr <16 x i16> @widening_mul_i8x16(<16 x i8> %x, <16 x i8> %y) nounw ; i16 -> i32 define weak_odr <4 x i32> @extmul_low_s_v4i32(<8 x i16> %v1, <8 x i16> %v2) nounwind alwaysinline { - %low1 = shufflevector <8 x i16> %v1, <8 x i16> undef, <4 x i32> - %low2 = shufflevector <8 x i16> %v2, <8 x i16> undef, <4 x i32> + %low1 = shufflevector <8 x i16> %v1, <8 x i16> poison, <4 x i32> + %low2 = shufflevector <8 x i16> %v2, <8 x i16> poison, <4 x i32> %extended1 = sext <4 x i16> %low1 to <4 x i32> %extended2 = sext <4 x i16> %low2 to <4 x i32> %a = mul <4 x i32> %extended1, %extended2 @@ -80,8 +80,8 @@ define weak_odr <4 x i32> @extmul_low_s_v4i32(<8 x i16> %v1, <8 x i16> %v2) noun } define weak_odr <4 x i32> @extmul_high_s_v4i32(<8 x i16> %v1, <8 x i16> %v2) nounwind alwaysinline { - %high1 = shufflevector <8 x i16> %v1, <8 x i16> undef, <4 x i32> - %high2 = shufflevector <8 x i16> %v2, <8 x i16> undef, <4 x i32> + %high1 = shufflevector <8 x i16> %v1, <8 x i16> poison, <4 x i32> + %high2 = shufflevector <8 x i16> %v2, <8 x i16> poison, <4 x i32> %extended1 = sext <4 x i16> %high1 to <4 x i32> %extended2 = sext <4 x i16> %high2 to <4 x i32> %a = mul <4 x i32> %extended1, %extended2 @@ -97,8 +97,8 @@ define weak_odr <8 x i32> @widening_mul_i16x8(<8 x i16> %x, <8 x i16> %y) nounwi ; i32 -> i64 define weak_odr <2 x i64> @extmul_low_s_v2i64(<4 x i32> %v1, <4 x i32> %v2) nounwind alwaysinline { - %low1 = shufflevector <4 x i32> %v1, <4 x i32> undef, <2 x i32> - %low2 = shufflevector <4 x i32> %v2, <4 x i32> undef, <2 x i32> + %low1 = shufflevector <4 x i32> %v1, <4 x i32> poison, <2 x i32> + %low2 = shufflevector <4 x i32> %v2, <4 x i32> poison, <2 x i32> %extended1 = sext <2 x i32> %low1 to <2 x i64> %extended2 = sext <2 x i32> %low2 to <2 x i64> %a = mul <2 x i64> %extended1, %extended2 @@ -106,8 +106,8 @@ define weak_odr <2 x i64> @extmul_low_s_v2i64(<4 x i32> %v1, <4 x i32> %v2) noun } define weak_odr <2 x i64> @extmul_high_s_v2i64(<4 x i32> %v1, <4 x i32> %v2) nounwind alwaysinline { - %high1 = shufflevector <4 x i32> %v1, <4 x i32> undef, <2 x i32> - %high2 = shufflevector <4 x i32> %v2, <4 x i32> undef, <2 x i32> + %high1 = shufflevector <4 x i32> %v1, <4 x i32> poison, <2 x i32> + %high2 = shufflevector <4 x i32> %v2, <4 x i32> poison, <2 x i32> %extended1 = sext <2 x i32> %high1 to <2 x i64> %extended2 = sext <2 x i32> %high2 to <2 x i64> %a = mul <2 x i64> %extended1, %extended2 @@ -123,8 +123,8 @@ define weak_odr <4 x i64> @widening_mul_i32x4(<4 x i32> %x, <4 x i32> %y) nounwi ; u8 -> u16 define weak_odr <8 x i16> @extmul_low_u_v8i16(<16 x i8> %v1, <16 x i8> %v2) nounwind alwaysinline { - %low1 = shufflevector <16 x i8> %v1, <16 x i8> undef, <8 x i32> - %low2 = shufflevector <16 x i8> %v2, <16 x i8> undef, <8 x i32> + %low1 = shufflevector <16 x i8> %v1, <16 x i8> poison, <8 x i32> + %low2 = shufflevector <16 x i8> %v2, <16 x i8> poison, <8 x i32> %extended1 = zext <8 x i8> %low1 to <8 x i16> %extended2 = zext <8 x i8> %low2 to <8 x i16> %a = mul <8 x i16> %extended1, %extended2 @@ -132,8 +132,8 @@ define weak_odr <8 x i16> @extmul_low_u_v8i16(<16 x i8> %v1, <16 x i8> %v2) noun } define weak_odr <8 x i16> @extmul_high_u_v8i16(<16 x i8> %v1, <16 x i8> %v2) nounwind alwaysinline { - %high1 = shufflevector <16 x i8> %v1, <16 x i8> undef, <8 x i32> - %high2 = shufflevector <16 x i8> %v2, <16 x i8> undef, <8 x i32> + %high1 = shufflevector <16 x i8> %v1, <16 x i8> poison, <8 x i32> + %high2 = shufflevector <16 x i8> %v2, <16 x i8> poison, <8 x i32> %extended1 = zext <8 x i8> %high1 to <8 x i16> %extended2 = zext <8 x i8> %high2 to <8 x i16> %a = mul <8 x i16> %extended1, %extended2 @@ -149,8 +149,8 @@ define weak_odr <16 x i16> @widening_mul_u8x16(<16 x i8> %x, <16 x i8> %y) nounw ; u16 -> u32 define weak_odr <4 x i32> @extmul_low_u_v4i32(<8 x i16> %v1, <8 x i16> %v2) nounwind alwaysinline { - %low1 = shufflevector <8 x i16> %v1, <8 x i16> undef, <4 x i32> - %low2 = shufflevector <8 x i16> %v2, <8 x i16> undef, <4 x i32> + %low1 = shufflevector <8 x i16> %v1, <8 x i16> poison, <4 x i32> + %low2 = shufflevector <8 x i16> %v2, <8 x i16> poison, <4 x i32> %extended1 = zext <4 x i16> %low1 to <4 x i32> %extended2 = zext <4 x i16> %low2 to <4 x i32> %a = mul <4 x i32> %extended1, %extended2 @@ -158,8 +158,8 @@ define weak_odr <4 x i32> @extmul_low_u_v4i32(<8 x i16> %v1, <8 x i16> %v2) noun } define weak_odr <4 x i32> @extmul_high_u_v4i32(<8 x i16> %v1, <8 x i16> %v2) nounwind alwaysinline { - %high1 = shufflevector <8 x i16> %v1, <8 x i16> undef, <4 x i32> - %high2 = shufflevector <8 x i16> %v2, <8 x i16> undef, <4 x i32> + %high1 = shufflevector <8 x i16> %v1, <8 x i16> poison, <4 x i32> + %high2 = shufflevector <8 x i16> %v2, <8 x i16> poison, <4 x i32> %extended1 = zext <4 x i16> %high1 to <4 x i32> %extended2 = zext <4 x i16> %high2 to <4 x i32> %a = mul <4 x i32> %extended1, %extended2 @@ -175,8 +175,8 @@ define weak_odr <8 x i32> @widening_mul_u16x8(<8 x i16> %x, <8 x i16> %y) nounwi ; u32 -> u64 define weak_odr <2 x i64> @extmul_low_u_v2i64(<4 x i32> %v1, <4 x i32> %v2) nounwind alwaysinline { - %low1 = shufflevector <4 x i32> %v1, <4 x i32> undef, <2 x i32> - %low2 = shufflevector <4 x i32> %v2, <4 x i32> undef, <2 x i32> + %low1 = shufflevector <4 x i32> %v1, <4 x i32> poison, <2 x i32> + %low2 = shufflevector <4 x i32> %v2, <4 x i32> poison, <2 x i32> %extended1 = zext <2 x i32> %low1 to <2 x i64> %extended2 = zext <2 x i32> %low2 to <2 x i64> %a = mul <2 x i64> %extended1, %extended2 @@ -184,8 +184,8 @@ define weak_odr <2 x i64> @extmul_low_u_v2i64(<4 x i32> %v1, <4 x i32> %v2) noun } define weak_odr <2 x i64> @extmul_high_u_v2i64(<4 x i32> %v1, <4 x i32> %v2) nounwind alwaysinline { - %high1 = shufflevector <4 x i32> %v1, <4 x i32> undef, <2 x i32> - %high2 = shufflevector <4 x i32> %v2, <4 x i32> undef, <2 x i32> + %high1 = shufflevector <4 x i32> %v1, <4 x i32> poison, <2 x i32> + %high2 = shufflevector <4 x i32> %v2, <4 x i32> poison, <2 x i32> %extended1 = zext <2 x i32> %high1 to <2 x i64> %extended2 = zext <2 x i32> %high2 to <2 x i64> %a = mul <2 x i64> %extended1, %extended2 @@ -207,29 +207,29 @@ declare <8 x i16> @llvm.wasm.narrow.signed.v8i16.v4i32(<4 x i32>, <4 x i32>) declare <8 x i16> @llvm.wasm.narrow.unsigned.v8i16.v4i32(<4 x i32>, <4 x i32>) define weak_odr <16 x i8> @saturating_narrow_i16x16_to_i8x16(<16 x i16> %x) nounwind alwaysinline { - %1 = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> - %2 = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> + %1 = shufflevector <16 x i16> %x, <16 x i16> poison, <8 x i32> + %2 = shufflevector <16 x i16> %x, <16 x i16> poison, <8 x i32> %3 = tail call <16 x i8> @llvm.wasm.narrow.signed.v16i8.v8i16(<8 x i16> %1, <8 x i16> %2) ret <16 x i8> %3 } define weak_odr <16 x i8> @saturating_narrow_i16x16_to_u8x16(<16 x i16> %x) nounwind alwaysinline { - %1 = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> - %2 = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> + %1 = shufflevector <16 x i16> %x, <16 x i16> poison, <8 x i32> + %2 = shufflevector <16 x i16> %x, <16 x i16> poison, <8 x i32> %3 = tail call <16 x i8> @llvm.wasm.narrow.unsigned.v16i8.v8i16(<8 x i16> %1, <8 x i16> %2) ret <16 x i8> %3 } define weak_odr <8 x i16> @saturating_narrow_i32x8_to_i16x8(<8 x i32> %x) nounwind alwaysinline { - %1 = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> - %2 = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> + %1 = shufflevector <8 x i32> %x, <8 x i32> poison, <4 x i32> + %2 = shufflevector <8 x i32> %x, <8 x i32> poison, <4 x i32> %3 = tail call <8 x i16> @llvm.wasm.narrow.signed.v8i16.v4i32(<4 x i32> %1, <4 x i32> %2) ret <8 x i16> %3 } define weak_odr <8 x i16> @saturating_narrow_i32x8_to_u16x8(<8 x i32> %x) nounwind alwaysinline { - %1 = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> - %2 = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> + %1 = shufflevector <8 x i32> %x, <8 x i32> poison, <4 x i32> + %2 = shufflevector <8 x i32> %x, <8 x i32> poison, <4 x i32> %3 = tail call <8 x i16> @llvm.wasm.narrow.unsigned.v8i16.v4i32(<4 x i32> %1, <4 x i32> %2) ret <8 x i16> %3 } @@ -245,8 +245,8 @@ define weak_odr <4 x double> @float_to_double(<4 x float> %x) nounwind alwaysinl ; i8 -> i16 define weak_odr <16 x i16> @extend_i8x16_to_i16x8(<16 x i8> %x) nounwind alwaysinline { - %1 = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> - %2 = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %1 = shufflevector <16 x i8> %x, <16 x i8> poison, <8 x i32> + %2 = shufflevector <16 x i8> %x, <16 x i8> poison, <8 x i32> %3 = sext <8 x i8> %1 to <8 x i16> %4 = sext <8 x i8> %2 to <8 x i16> %5 = shufflevector <8 x i16> %3, <8 x i16> %4, <16 x i32> @@ -256,8 +256,8 @@ define weak_odr <16 x i16> @extend_i8x16_to_i16x8(<16 x i8> %x) nounwind alwaysi ; u8 -> u16 define weak_odr <16 x i16> @extend_u8x16_to_u16x8(<16 x i8> %x) nounwind alwaysinline { - %1 = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> - %2 = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %1 = shufflevector <16 x i8> %x, <16 x i8> poison, <8 x i32> + %2 = shufflevector <16 x i8> %x, <16 x i8> poison, <8 x i32> %3 = zext <8 x i8> %1 to <8 x i16> %4 = zext <8 x i8> %2 to <8 x i16> %5 = shufflevector <8 x i16> %3, <8 x i16> %4, <16 x i32> @@ -267,8 +267,8 @@ define weak_odr <16 x i16> @extend_u8x16_to_u16x8(<16 x i8> %x) nounwind alwaysi ; i16 -> i32 define weak_odr <8 x i32> @extend_i16x8_to_i32x8(<8 x i16> %x) nounwind alwaysinline { - %1 = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> - %2 = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %1 = shufflevector <8 x i16> %x, <8 x i16> poison, <4 x i32> + %2 = shufflevector <8 x i16> %x, <8 x i16> poison, <4 x i32> %3 = sext <4 x i16> %1 to <4 x i32> %4 = sext <4 x i16> %2 to <4 x i32> %5 = shufflevector <4 x i32> %3, <4 x i32> %4, <8 x i32> @@ -278,8 +278,8 @@ define weak_odr <8 x i32> @extend_i16x8_to_i32x8(<8 x i16> %x) nounwind alwaysin ; u16 -> u32 define weak_odr <8 x i32> @extend_u16x8_to_u32x8(<8 x i16> %x) nounwind alwaysinline { - %1 = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> - %2 = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %1 = shufflevector <8 x i16> %x, <8 x i16> poison, <4 x i32> + %2 = shufflevector <8 x i16> %x, <8 x i16> poison, <4 x i32> %3 = zext <4 x i16> %1 to <4 x i32> %4 = zext <4 x i16> %2 to <4 x i32> %5 = shufflevector <4 x i32> %3, <4 x i32> %4, <8 x i32> @@ -289,8 +289,8 @@ define weak_odr <8 x i32> @extend_u16x8_to_u32x8(<8 x i16> %x) nounwind alwaysin ; i32 -> i64 define weak_odr <4 x i64> @extend_i32x4_to_i64x4(<4 x i32> %x) nounwind alwaysinline { - %1 = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> - %2 = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %1 = shufflevector <4 x i32> %x, <4 x i32> poison, <2 x i32> + %2 = shufflevector <4 x i32> %x, <4 x i32> poison, <2 x i32> %3 = sext <2 x i32> %1 to <2 x i64> %4 = sext <2 x i32> %2 to <2 x i64> %5 = shufflevector <2 x i64> %3, <2 x i64> %4, <4 x i32> @@ -300,8 +300,8 @@ define weak_odr <4 x i64> @extend_i32x4_to_i64x4(<4 x i32> %x) nounwind alwaysin ; u32 -> u64 define weak_odr <4 x i64> @extend_u32x4_to_u64x4(<4 x i32> %x) nounwind alwaysinline { - %1 = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> - %2 = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %1 = shufflevector <4 x i32> %x, <4 x i32> poison, <2 x i32> + %2 = shufflevector <4 x i32> %x, <4 x i32> poison, <2 x i32> %3 = zext <2 x i32> %1 to <2 x i64> %4 = zext <2 x i32> %2 to <2 x i64> %5 = shufflevector <2 x i64> %3, <2 x i64> %4, <4 x i32> diff --git a/src/runtime/x86.ll b/src/runtime/x86.ll index 31cea48ffcd5..5e6b5613e9f6 100644 --- a/src/runtime/x86.ll +++ b/src/runtime/x86.ll @@ -31,26 +31,48 @@ declare <16 x i8> @llvm.x86.sse2.packuswb.128(<8 x i16>, <8 x i16>) declare <8 x i16> @llvm.x86.sse2.packssdw.128(<4 x i32>, <4 x i32>) define weak_odr <16 x i8> @packsswbx16(<16 x i16> %arg) nounwind alwaysinline { - %1 = shufflevector <16 x i16> %arg, <16 x i16> undef, <8 x i32> - %2 = shufflevector <16 x i16> %arg, <16 x i16> undef, <8 x i32> + %1 = shufflevector <16 x i16> %arg, <16 x i16> poison, <8 x i32> + %2 = shufflevector <16 x i16> %arg, <16 x i16> poison, <8 x i32> %3 = tail call <16 x i8> @llvm.x86.sse2.packsswb.128(<8 x i16> %1, <8 x i16> %2) ret <16 x i8> %3 } define weak_odr <16 x i8> @packuswbx16(<16 x i16> %arg) nounwind alwaysinline { - %1 = shufflevector <16 x i16> %arg, <16 x i16> undef, <8 x i32> - %2 = shufflevector <16 x i16> %arg, <16 x i16> undef, <8 x i32> + %1 = shufflevector <16 x i16> %arg, <16 x i16> poison, <8 x i32> + %2 = shufflevector <16 x i16> %arg, <16 x i16> poison, <8 x i32> %3 = tail call <16 x i8> @llvm.x86.sse2.packuswb.128(<8 x i16> %1, <8 x i16> %2) ret <16 x i8> %3 } define weak_odr <8 x i16> @packssdwx8(<8 x i32> %arg) nounwind alwaysinline { - %1 = shufflevector <8 x i32> %arg, <8 x i32> undef, <4 x i32> - %2 = shufflevector <8 x i32> %arg, <8 x i32> undef, <4 x i32> < i32 4, i32 5, i32 6, i32 7> + %1 = shufflevector <8 x i32> %arg, <8 x i32> poison, <4 x i32> + %2 = shufflevector <8 x i32> %arg, <8 x i32> poison, <4 x i32> < i32 4, i32 5, i32 6, i32 7> %3 = tail call <8 x i16> @llvm.x86.sse2.packssdw.128(<4 x i32> %1, <4 x i32> %2) ret <8 x i16> %3 } +define weak_odr <8 x i32> @wmul_pmaddwd_avx2(<8 x i16> %a, <8 x i16> %b) nounwind alwaysinline { + %1 = zext <8 x i16> %a to <8 x i32> + %2 = zext <8 x i16> %b to <8 x i32> + %3 = bitcast <8 x i32> %1 to <16 x i16> + %4 = bitcast <8 x i32> %2 to <16 x i16> + %res = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %3, <16 x i16> %4) + ret <8 x i32> %res +} + +declare <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16>, <16 x i16>) nounwind readnone + +define weak_odr <4 x i32> @wmul_pmaddwd_sse2(<4 x i16> %a, <4 x i16> %b) nounwind alwaysinline { + %1 = zext <4 x i16> %a to <4 x i32> + %2 = zext <4 x i16> %b to <4 x i32> + %3 = bitcast <4 x i32> %1 to <8 x i16> + %4 = bitcast <4 x i32> %2 to <8 x i16> + %res = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %3, <8 x i16> %4) + ret <4 x i32> %res +} + +declare <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16>, <8 x i16>) nounwind readnone + define weak_odr <4 x float> @sqrt_f32x4(<4 x float> %x) nounwind uwtable readnone alwaysinline { %1 = tail call <4 x float> @llvm.x86.sse.sqrt.ps(<4 x float> %x) nounwind ret <4 x float> %1 @@ -83,7 +105,7 @@ define weak_odr <2 x double> @abs_f64x2(<2 x double> %x) nounwind uwtable readno declare <4 x float> @llvm.x86.sse.rcp.ss(<4 x float>) nounwind readnone define weak_odr float @fast_inverse_f32(float %x) nounwind uwtable readnone alwaysinline { - %vec = insertelement <4 x float> undef, float %x, i32 0 + %vec = insertelement <4 x float> poison, float %x, i32 0 %approx = tail call <4 x float> @llvm.x86.sse.rcp.ss(<4 x float> %vec) %result = extractelement <4 x float> %approx, i32 0 ret float %result @@ -99,7 +121,7 @@ define weak_odr <4 x float> @fast_inverse_f32x4(<4 x float> %x) nounwind uwtable declare <4 x float> @llvm.x86.sse.rsqrt.ss(<4 x float>) nounwind readnone define weak_odr float @fast_inverse_sqrt_f32(float %x) nounwind uwtable readnone alwaysinline { - %vec = insertelement <4 x float> undef, float %x, i32 0 + %vec = insertelement <4 x float> poison, float %x, i32 0 %approx = tail call <4 x float> @llvm.x86.sse.rsqrt.ss(<4 x float> %vec) %result = extractelement <4 x float> %approx, i32 0 ret float %result diff --git a/src/runtime/x86_avx2.ll b/src/runtime/x86_avx2.ll index a73736860682..d4d88be839c6 100644 --- a/src/runtime/x86_avx2.ll +++ b/src/runtime/x86_avx2.ll @@ -1,31 +1,31 @@ define weak_odr <16 x i16> @packssdwx16(<16 x i32> %arg) nounwind alwaysinline { - %1 = shufflevector <16 x i32> %arg, <16 x i32> undef, <8 x i32> - %2 = shufflevector <16 x i32> %arg, <16 x i32> undef, <8 x i32> + %1 = shufflevector <16 x i32> %arg, <16 x i32> poison, <8 x i32> + %2 = shufflevector <16 x i32> %arg, <16 x i32> poison, <8 x i32> %3 = tail call <16 x i16> @llvm.x86.avx2.packssdw(<8 x i32> %1, <8 x i32> %2) ret <16 x i16> %3 } declare <16 x i16> @llvm.x86.avx2.packssdw(<8 x i32>, <8 x i32>) define weak_odr <32 x i8> @packuswbx32(<32 x i16> %arg) nounwind alwaysinline { - %1 = shufflevector <32 x i16> %arg, <32 x i16> undef, <16 x i32> - %2 = shufflevector <32 x i16> %arg, <32 x i16> undef, <16 x i32> + %1 = shufflevector <32 x i16> %arg, <32 x i16> poison, <16 x i32> + %2 = shufflevector <32 x i16> %arg, <32 x i16> poison, <16 x i32> %3 = call <32 x i8> @llvm.x86.avx2.packuswb(<16 x i16> %1, <16 x i16> %2) ret <32 x i8> %3 } declare <32 x i8> @llvm.x86.avx2.packuswb(<16 x i16>, <16 x i16>) define weak_odr <32 x i8> @packsswbx32(<32 x i16> %arg) nounwind alwaysinline { - %1 = shufflevector <32 x i16> %arg, <32 x i16> undef, <16 x i32> - %2 = shufflevector <32 x i16> %arg, <32 x i16> undef, <16 x i32> + %1 = shufflevector <32 x i16> %arg, <32 x i16> poison, <16 x i32> + %2 = shufflevector <32 x i16> %arg, <32 x i16> poison, <16 x i32> %3 = call <32 x i8> @llvm.x86.avx2.packsswb(<16 x i16> %1, <16 x i16> %2) ret <32 x i8> %3 } declare <32 x i8> @llvm.x86.avx2.packsswb(<16 x i16>, <16 x i16>) define weak_odr <16 x i16> @packusdwx16(<16 x i32> %arg) nounwind alwaysinline { - %1 = shufflevector <16 x i32> %arg, <16 x i32> undef, <8 x i32> - %2 = shufflevector <16 x i32> %arg, <16 x i32> undef, <8 x i32> + %1 = shufflevector <16 x i32> %arg, <16 x i32> poison, <8 x i32> + %2 = shufflevector <16 x i32> %arg, <16 x i32> poison, <8 x i32> %3 = tail call <16 x i16> @llvm.x86.avx2.packusdw(<8 x i32> %1, <8 x i32> %2) ret <16 x i16> %3 } @@ -61,3 +61,14 @@ define weak_odr <16 x i16> @saturating_pmulhrswx16(<16 x i16> %a, <16 x i16> %b) ret <16 x i16> %5 } declare <16 x i16> @llvm.x86.avx2.pmul.hr.sw(<16 x i16>, <16 x i16>) nounwind readnone + +define weak_odr <16 x i16> @hadd_pmadd_u8_avx2(<32 x i8> %a) nounwind alwaysinline { + %1 = tail call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> %a, <32 x i8> ) + ret <16 x i16> %1 +} + +define weak_odr <16 x i16> @hadd_pmadd_i8_avx2(<32 x i8> %a) nounwind alwaysinline { + %1 = tail call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> , <32 x i8> %a) + ret <16 x i16> %1 +} +declare <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8>, <32 x i8>) nounwind readnone diff --git a/src/runtime/x86_avx512.ll b/src/runtime/x86_avx512.ll index 730014c99d9a..8cbc8abb9c5d 100644 --- a/src/runtime/x86_avx512.ll +++ b/src/runtime/x86_avx512.ll @@ -1,8 +1,8 @@ ; Split a 32 element f32 vector into two 16 element vectors to use the cvtne2ps2bf16 intrinsic. define weak_odr <32 x i16> @vcvtne2ps2bf16x32(<32 x float> %arg) nounwind alwaysinline { - %1 = shufflevector <32 x float> %arg, <32 x float> undef, <16 x i32> - %2 = shufflevector <32 x float> %arg, <32 x float> undef, <16 x i32> + %1 = shufflevector <32 x float> %arg, <32 x float> poison, <16 x i32> + %2 = shufflevector <32 x float> %arg, <32 x float> poison, <16 x i32> %3 = tail call <32 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float> %2, <16 x float> %1) ret <32 x i16> %3 } @@ -11,8 +11,8 @@ declare <32 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float>, <16 x fl ; LLVM does not have an unmasked version of cvtneps2bf16.128, so provide a wrapper around the masked version. define weak_odr <4 x i16> @vcvtneps2bf16x4(<4 x float> %arg) nounwind alwaysinline { - %1 = tail call <8 x i16> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %arg, <8 x i16> undef, <4 x i1> ) - %2 = shufflevector <8 x i16> %1, <8 x i16> undef, <4 x i32> + %1 = tail call <8 x i16> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %arg, <8 x i16> poison, <4 x i1> ) + %2 = shufflevector <8 x i16> %1, <8 x i16> poison, <4 x i32> ret <4 x i16> %2 } diff --git a/src/runtime/x86_sse41.ll b/src/runtime/x86_sse41.ll index 3ca654d0e874..d181de3d67e8 100644 --- a/src/runtime/x86_sse41.ll +++ b/src/runtime/x86_sse41.ll @@ -1,8 +1,8 @@ declare <8 x i16> @llvm.x86.sse41.packusdw(<4 x i32>, <4 x i32>) nounwind readnone define weak_odr <8 x i16> @packusdwx8(<8 x i32> %arg) nounwind alwaysinline { - %1 = shufflevector <8 x i32> %arg, <8 x i32> undef, <4 x i32> - %2 = shufflevector <8 x i32> %arg, <8 x i32> undef, <4 x i32> < i32 4, i32 5, i32 6, i32 7> + %1 = shufflevector <8 x i32> %arg, <8 x i32> poison, <4 x i32> + %2 = shufflevector <8 x i32> %arg, <8 x i32> poison, <4 x i32> < i32 4, i32 5, i32 6, i32 7> %3 = tail call <8 x i16> @llvm.x86.sse41.packusdw(<4 x i32> %1, <4 x i32> %2) ret <8 x i16> %3 } @@ -81,3 +81,14 @@ define weak_odr <8 x i16> @saturating_pmulhrswx8(<8 x i16> %a, <8 x i16> %b) nou ret <8 x i16> %5 } declare <8 x i16> @llvm.x86.ssse3.pmul.hr.sw.128(<8 x i16>, <8 x i16>) nounwind readnone + +define weak_odr <8 x i16> @hadd_pmadd_u8_sse3(<16 x i8> %a) nounwind alwaysinline { + %1 = tail call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> %a, <16 x i8> ) + ret <8 x i16> %1 +} + +define weak_odr <8 x i16> @hadd_pmadd_i8_sse3(<16 x i8> %a) nounwind alwaysinline { + %1 = tail call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> , <16 x i8> %a) + ret <8 x i16> %1 +} +declare <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8>, <16 x i8>) nounwind readnone diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index c5147af5adc7..ca1e3f46acf8 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -38,3 +38,26 @@ option(WITH_TEST_GENERATOR "Build generator tests" ON) if (WITH_TEST_GENERATOR) add_subdirectory(generator) endif () + +# FIXME: Disable the runtime tests for MSVC until we have a MS compatible header. +# +# The runtime tests include src/runtime/runtime_internal.h which was written +# to only support clang (GCC's front end is close enough it works fine as well). +# We originally setup the tests to compile with clang (in the same way as the actual +# runtime bitcode files), but that wasn't very clean and didn't integrate well with +# the other tests, so we switched to just using the native system compiler. +# Sadly MSVC isn't compatible with the current runtime_internal.h which would need +# some platform specific ifdefs for attributes and types that are causing compile +# errors. +# +cmake_dependent_option(WITH_TEST_RUNTIME "Build runtime tests" ON + "NOT MSVC" OFF) + +if (WITH_TEST_RUNTIME) + message(STATUS "Building internal runtime tests enabled") + add_subdirectory(runtime) +else () + message(STATUS "Building internal runtime tests disabled") +endif () + +# FIXME: failing_with_issue is dead code :) diff --git a/test/auto_schedule/CMakeLists.txt b/test/auto_schedule/CMakeLists.txt index 668175d51ab5..e4dc2b8e671b 100644 --- a/test/auto_schedule/CMakeLists.txt +++ b/test/auto_schedule/CMakeLists.txt @@ -1,9 +1,14 @@ if (TARGET Halide::Mullapudi2016) tests(GROUPS auto_schedule + SOURCES + extern.cpp + param.cpp + ARGS $) + + tests(GROUPS auto_schedule multithreaded SOURCES cost_function.cpp data_dependent.cpp - extern.cpp fibonacci.cpp histogram.cpp large_window.cpp @@ -11,7 +16,6 @@ if (TARGET Halide::Mullapudi2016) max_filter.cpp multi_output.cpp overlap.cpp - param.cpp reorder.cpp small_pure_update.cpp tile_vs_inline.cpp diff --git a/test/auto_schedule/cost_function.cpp b/test/auto_schedule/cost_function.cpp index 3eb027db98c8..5006674bc188 100644 --- a/test/auto_schedule/cost_function.cpp +++ b/test/auto_schedule/cost_function.cpp @@ -48,7 +48,11 @@ int main(int argc, char **argv) { // Auto-schedule the pipeline Target target = get_jit_target_from_environment(); Pipeline p(stencils[num_stencils - 1]); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API AutoSchedulerResults results = p.auto_schedule(target); +#else + AutoSchedulerResults results = p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif std::cout << "\n\n******************************************\nSCHEDULE:\n" << "******************************************\n" diff --git a/test/auto_schedule/data_dependent.cpp b/test/auto_schedule/data_dependent.cpp index 5a54626c4763..828a1061cd3e 100644 --- a/test/auto_schedule/data_dependent.cpp +++ b/test/auto_schedule/data_dependent.cpp @@ -40,7 +40,11 @@ int main(int argc, char **argv) { Target target = get_jit_target_from_environment(); Pipeline p(g); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif // Inspect the schedule g.print_loop_nest(); diff --git a/test/auto_schedule/extern.cpp b/test/auto_schedule/extern.cpp index e442d8aca902..8cd4b5181c2c 100644 --- a/test/auto_schedule/extern.cpp +++ b/test/auto_schedule/extern.cpp @@ -1,14 +1,8 @@ #include "Halide.h" #include -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - // An extern stage that translates. -extern "C" DLLEXPORT int translate(halide_buffer_t *in, int dx, int dy, halide_buffer_t *out) { +extern "C" HALIDE_EXPORT_SYMBOL int translate(halide_buffer_t *in, int dx, int dy, halide_buffer_t *out) { if (in->is_bounds_query()) { in->dim[0].min = out->dim[0].min + dx; @@ -58,7 +52,11 @@ void test_case_1() { Target target = get_jit_target_from_environment(); Pipeline p(g); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif // Inspect the schedule g.print_loop_nest(); @@ -88,7 +86,11 @@ void test_case_2() { Target target = get_jit_target_from_environment(); Pipeline p(g); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif // Inspect the schedule g.print_loop_nest(); @@ -120,7 +122,11 @@ void test_case_3() { Target target = get_jit_target_from_environment(); Pipeline p(g); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif // Inspect the schedule g.print_loop_nest(); diff --git a/test/auto_schedule/fibonacci.cpp b/test/auto_schedule/fibonacci.cpp index a394af50a921..0d2a05a3001b 100644 --- a/test/auto_schedule/fibonacci.cpp +++ b/test/auto_schedule/fibonacci.cpp @@ -22,7 +22,11 @@ double run_test(bool auto_schedule) { if (auto_schedule) { // Auto-schedule the pipeline +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif } // Inspect the schedule diff --git a/test/auto_schedule/histogram.cpp b/test/auto_schedule/histogram.cpp index c51cac7436b4..0cc4f151030b 100644 --- a/test/auto_schedule/histogram.cpp +++ b/test/auto_schedule/histogram.cpp @@ -64,7 +64,11 @@ double run_test(bool auto_schedule) { // Provide estimates on the pipeline output color.set_estimates({{0, 1920}, {0, 1024}, {0, 3}}); // Auto-schedule the pipeline +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif } else if (target.has_gpu_feature()) { Var xi("xi"), yi("yi"); Y.compute_root().gpu_tile(x, y, xi, yi, 16, 16); diff --git a/test/auto_schedule/large_window.cpp b/test/auto_schedule/large_window.cpp index 2626b9a2508b..c449d7136873 100644 --- a/test/auto_schedule/large_window.cpp +++ b/test/auto_schedule/large_window.cpp @@ -46,7 +46,11 @@ int main(int argc, char **argv) { Target target = get_jit_target_from_environment(); Pipeline p(g); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif // Inspect the schedule g.print_loop_nest(); diff --git a/test/auto_schedule/mat_mul.cpp b/test/auto_schedule/mat_mul.cpp index 07e5fefce2ca..73bac853d393 100644 --- a/test/auto_schedule/mat_mul.cpp +++ b/test/auto_schedule/mat_mul.cpp @@ -40,7 +40,11 @@ double run_test(bool auto_schedule) { // Provide estimates on the pipeline output out.set_estimate(x, 0, size).set_estimate(y, 0, size); // Auto-schedule the pipeline +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif } else if (target.has_gpu_feature()) { Var xi("xi"), yi("yi"), xii("xii"), yii("yii"), xt("xt"), yt("yt"); out.tile(x, y, xi, yi, 8, 8).unroll(xi).unroll(yi).gpu_tile(x, y, xt, yt, 8, 8); diff --git a/test/auto_schedule/max_filter.cpp b/test/auto_schedule/max_filter.cpp index fa9b72706d5d..f9d7e0854012 100644 --- a/test/auto_schedule/max_filter.cpp +++ b/test/auto_schedule/max_filter.cpp @@ -72,7 +72,11 @@ double run_test(bool auto_schedule) { .set_estimate(y, 0, in.height()) .set_estimate(c, 0, in.channels()); // Auto-schedule the pipeline +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif } else if (target.has_gpu_feature()) { slice_for_radius.compute_root(); filter_height.compute_root(); diff --git a/test/auto_schedule/multi_output.cpp b/test/auto_schedule/multi_output.cpp index f00f4ee09fa3..3ad372568e13 100644 --- a/test/auto_schedule/multi_output.cpp +++ b/test/auto_schedule/multi_output.cpp @@ -44,10 +44,14 @@ int main(int argc, char **argv) { std::vector outs; outs.push_back(h); outs.push_back(g); - Pipeline test(outs); + Pipeline p(outs); Target target = get_jit_target_from_environment(); - test.auto_schedule(target); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif // Inspect the schedule h.print_loop_nest(); @@ -56,7 +60,7 @@ int main(int argc, char **argv) { Buffer out_1(999, 999), out_2(999, 999); // Run the schedule - test.realize({out_1, out_2}); + p.realize({out_1, out_2}); printf("Success!\n"); return 0; diff --git a/test/auto_schedule/overlap.cpp b/test/auto_schedule/overlap.cpp index 8fe4a0b5aa1f..2f747879244f 100644 --- a/test/auto_schedule/overlap.cpp +++ b/test/auto_schedule/overlap.cpp @@ -50,7 +50,11 @@ int main(int argc, char **argv) { Target target = get_jit_target_from_environment(); Pipeline p(up[num_levels - 1]); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif // Inspect the schedule up[num_levels - 1].print_loop_nest(); diff --git a/test/auto_schedule/param.cpp b/test/auto_schedule/param.cpp index 1db0458d0e2f..7102e1d61217 100644 --- a/test/auto_schedule/param.cpp +++ b/test/auto_schedule/param.cpp @@ -23,7 +23,11 @@ void run_test_1() { Target target = get_jit_target_from_environment(); Pipeline p(g); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif // Inspect the schedule g.print_loop_nest(); @@ -50,7 +54,11 @@ void run_test_2() { Target target = get_jit_target_from_environment(); Pipeline p(g); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif // Inspect the schedule g.print_loop_nest(); @@ -77,7 +85,11 @@ void run_test_3() { Target target = get_jit_target_from_environment(); Pipeline p(output); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif // Inspect the schedule output.print_loop_nest(); @@ -107,7 +119,11 @@ void run_test_4() { Target target = get_jit_target_from_environment(); Pipeline p(output); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif // Inspect the schedule output.print_loop_nest(); diff --git a/test/auto_schedule/reorder.cpp b/test/auto_schedule/reorder.cpp index 24c4893051f7..ba15be2544aa 100644 --- a/test/auto_schedule/reorder.cpp +++ b/test/auto_schedule/reorder.cpp @@ -27,7 +27,11 @@ double run_test_1(bool auto_schedule) { // Provide estimates on the pipeline output r.set_estimates({{0, 1024}, {0, 1024}, {0, 3}}); // Auto-schedule the pipeline +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif } else { /* r.update(0).fuse(c, y, par).parallel(par).reorder(x, dom.x, dom.y).vectorize(x, 4); @@ -79,7 +83,11 @@ double run_test_2(bool auto_schedule) { // Provide estimates on the pipeline output diff.set_estimates({{0, left_im.width()}, {0, left_im.height()}, {0, 32}, {0, 3}}); // Auto-schedule the pipeline +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif } else { Var t("t"); diff.reorder(c, z).fuse(c, z, t).parallel(t).vectorize(x, 16); @@ -118,7 +126,11 @@ double run_test_3(bool auto_schedule) { // Provide estimates on the pipeline output r.set_estimates({{0, 1024}, {0, 1024}, {0, 3}}); // Auto-schedule the pipeline +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif } else { Var par("par"); r.update(0).fuse(c, y, par).parallel(par).reorder(x, dom.x, dom.y).vectorize(x, 4); diff --git a/test/auto_schedule/small_pure_update.cpp b/test/auto_schedule/small_pure_update.cpp index 4ef2649048ee..3954c257015a 100644 --- a/test/auto_schedule/small_pure_update.cpp +++ b/test/auto_schedule/small_pure_update.cpp @@ -28,8 +28,13 @@ int main(int argc, char **argv) { h.set_estimates({{0, 13}, {0, 17}}); in_param.set_estimates({{0, 13}, {0, 17}}); + Target target = get_target_from_environment(); Pipeline p(h); - p.auto_schedule(Target("host")); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif in_param.set(in); diff --git a/test/auto_schedule/tile_vs_inline.cpp b/test/auto_schedule/tile_vs_inline.cpp index 1c067cd81ab7..01ebaa15baca 100644 --- a/test/auto_schedule/tile_vs_inline.cpp +++ b/test/auto_schedule/tile_vs_inline.cpp @@ -44,7 +44,11 @@ int main(int argc, char **argv) { Target target = get_jit_target_from_environment(); Pipeline p(g); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif // Inspect the schedule g.print_loop_nest(); diff --git a/test/auto_schedule/unused_func.cpp b/test/auto_schedule/unused_func.cpp index bac796b6baa3..406ba438f0c9 100644 --- a/test/auto_schedule/unused_func.cpp +++ b/test/auto_schedule/unused_func.cpp @@ -28,7 +28,11 @@ int main(int argc, char **argv) { Target target = get_jit_target_from_environment(); Pipeline p(f); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif // Inspect the schedule f.print_loop_nest(); diff --git a/test/auto_schedule/vectorize_var_in_update.cpp b/test/auto_schedule/vectorize_var_in_update.cpp index 8b0f6881220f..13f9bf155bb9 100644 --- a/test/auto_schedule/vectorize_var_in_update.cpp +++ b/test/auto_schedule/vectorize_var_in_update.cpp @@ -50,7 +50,11 @@ int main(int argc, char **argv) { Target target = get_jit_target_from_environment(); Pipeline p(h); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif // Inspect the schedule h.print_loop_nest(); diff --git a/test/common/test_sharding.h b/test/common/test_sharding.h new file mode 100644 index 000000000000..6266dea0a387 --- /dev/null +++ b/test/common/test_sharding.h @@ -0,0 +1,89 @@ +#ifndef TEST_SHARDING_H +#define TEST_SHARDING_H + +// This file may be used by AOT tests, so it deliberately does not +// include Halide.h + +#include +#include +#include +#include +#include + +namespace Halide { +namespace Internal { +namespace Test { + +// Support the environment variables are used by the GoogleTest framework +// to allow a large test to be 'sharded' into smaller pieces: +// +// - If TEST_SHARD_STATUS_FILE is not empty, we should create a file at that path +// to indicate to the test framework that we support sharding. (Note that this +// must be done even if the test does a [SKIP] and executes no tests.) +// - If TEST_TOTAL_SHARDS and TEST_SHARD_INDEX are defined, we should +// split our work into TEST_TOTAL_SHARDS chunks, and only do the TEST_SHARD_INDEX-th +// chunk on this run. +// +// The Halide buildbots don't (yet) make use of these, but some downstream consumers do. + +class Sharder { + // returns empty string (not null) if env var not found + static std::string get_env(const char *v) { + const char *r = getenv(v); + if (!r) r = ""; + return r; + } + + // returns 0 if env var not found + static int32_t get_env_i32(const char *v) { + return std::atoi(get_env(v).c_str()); // 0 if not found + } + + const size_t total_shards, shard_index; + +public: + // Available publicly in case the test is skipped via [SKIP] -- + // even if the test runs nothing, we still need to write to this file + // (if requested) to avoid making the external test framework unhappy. + // (We don't need to call it when actually instantiating a Sharder.) + static void accept_sharded_status() { + std::string shard_status_file = get_env("TEST_SHARD_STATUS_FILE"); + if (!shard_status_file.empty()) { + std::ofstream f(shard_status_file, std::ios::out | std::ios::binary); + f << "sharder\n"; + f.flush(); + f.close(); + } + } + + explicit Sharder() + : total_shards(get_env_i32("TEST_TOTAL_SHARDS")), + shard_index(get_env_i32("TEST_SHARD_INDEX")) { + + accept_sharded_status(); + if (total_shards != 0) { + if (total_shards < 0 || shard_index < 0 || shard_index >= total_shards) { + std::cerr << "Illegal values for sharding: total " << total_shards << " current " << shard_index << "\n"; + exit(-1); + } + } + } + + bool should_run(size_t task_index) const { + if (total_shards > 0) { + return (task_index % total_shards) == shard_index; + } else { + return true; + } + } + + bool is_sharded() const { + return total_shards > 0; + } +}; + +} // namespace Test +} // namespace Internal +} // namespace Halide + +#endif // TEST_SHARDING_H diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 45e21c234db1..10a108553231 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -7,20 +7,15 @@ tests(GROUPS correctness SOURCES align_bounds.cpp argmax.cpp - assertion_failure_in_parallel_for.cpp - async.cpp - async_copy_chain.cpp async_device_copy.cpp - atomic_tuples.cpp - atomics.cpp autodiff.cpp bad_likely.cpp bit_counting.cpp bitwise_ops.cpp bool_compute_root_vectorize.cpp bound.cpp - bound_storage.cpp bound_small_allocations.cpp + bound_storage.cpp boundary_conditions.cpp bounds.cpp bounds_inference.cpp @@ -36,6 +31,10 @@ tests(GROUPS correctness bounds_query.cpp buffer_t.cpp c_function.cpp + callable.cpp + callable_errors.cpp + callable_generator.cpp + callable_typed.cpp cascaded_filters.cpp cast.cpp cast_handle.cpp @@ -51,8 +50,6 @@ tests(GROUPS correctness compute_at_reordered_update_stage.cpp compute_at_split_rvar.cpp compute_inside_guard.cpp - compute_outermost.cpp - compute_with.cpp compute_with_in.cpp compute_with_inlined.cpp computed_index.cpp @@ -60,7 +57,6 @@ tests(GROUPS correctness constant_expr.cpp constant_type.cpp constraints.cpp - convolution.cpp convolution_multiple_kernels.cpp cross_compilation.cpp cse_nan.cpp @@ -85,22 +81,21 @@ tests(GROUPS correctness div_by_zero.cpp dynamic_allocation_in_gpu_kernel.cpp dynamic_reduction_bounds.cpp + early_out.cpp embed_bitcode.cpp erf.cpp exception.cpp explicit_inline_reductions.cpp extern_bounds_inference.cpp extern_consumer.cpp - extern_consumer_tiled.cpp extern_error.cpp extern_output_expansion.cpp extern_partial.cpp extern_producer.cpp extern_reorder_storage.cpp extern_sort.cpp - extern_stage.cpp extern_stage_on_device.cpp - external_code.cpp + extract_concat_bits.cpp failed_unroll.cpp fast_trigonometric.cpp fibonacci.cpp @@ -112,10 +107,8 @@ tests(GROUPS correctness float16_t_neon_op_check.cpp for_each_element.cpp force_onto_stack.cpp - func_clone.cpp func_lifetime.cpp func_lifetime_2.cpp - func_wrapper.cpp fuse.cpp fuse_gpu_threads.cpp fused_where_inner_extent_is_zero.cpp @@ -170,7 +163,6 @@ tests(GROUPS correctness host_alignment.cpp image_io.cpp image_of_lists.cpp - image_wrapper.cpp implicit_args.cpp implicit_args_tests.cpp in_place.cpp @@ -184,7 +176,6 @@ tests(GROUPS correctness interleave.cpp interleave_rgb.cpp interleave_x.cpp - interpreter.cpp interval.cpp intrinsics.cpp introspection.cpp @@ -196,7 +187,6 @@ tests(GROUPS correctness lazy_convolution.cpp leak_device_memory.cpp left_shift_negative.cpp - legal_race_condition.cpp lerp.cpp let_in_rdom_bound.cpp likely.cpp @@ -205,48 +195,33 @@ tests(GROUPS correctness loop_invariant_extern_calls.cpp loop_level_generator_param.cpp lossless_cast.cpp - lots_of_dimensions.cpp lots_of_loop_invariants.cpp + low_bit_depth_noise.cpp make_struct.cpp many_dimensions.cpp many_small_extern_stages.cpp many_updates.cpp math.cpp median3x3.cpp - memoize.cpp memoize_cloned.cpp min_extent.cpp mod.cpp mul_div_mod.cpp multi_output_pipeline_with_bad_sizes.cpp - multi_pass_reduction.cpp multi_splits_with_diff_tail_strategies.cpp multi_way_select.cpp multipass_constraints.cpp multiple_outputs.cpp - multiple_outputs_extern.cpp - multiple_scatter.cpp mux.cpp - named_updates.cpp - nested_shiftinwards.cpp nested_tail_strategies.cpp newtons_method.cpp non_nesting_extern_bounds_query.cpp non_vector_aligned_embeded_buffer.cpp obscure_image_references.cpp - oddly_sized_output.cpp out_constraint.cpp out_of_memory.cpp output_larger_than_two_gigs.cpp - parallel.cpp - parallel_alloc.cpp - parallel_fork.cpp parallel_gpu_nested.cpp - parallel_nested.cpp - parallel_nested_1.cpp - parallel_scatter.cpp - parallel_reductions.cpp - parallel_rvar.cpp param.cpp param_map.cpp parameter_constraints.cpp @@ -266,25 +241,25 @@ tests(GROUPS correctness pseudostack_shares_slots.cpp python_extension_gen.cpp pytorch.cpp - random.cpp + realize_condition_depends_on_tuple.cpp realize_larger_than_two_gigs.cpp realize_over_shifted_domain.cpp reduction_chain.cpp + reduction_predicate_racing.cpp reduction_non_rectangular.cpp reduction_schedule.cpp register_shuffle.cpp - reorder_rvars.cpp reorder_storage.cpp require.cpp reschedule.cpp reuse_stack_alloc.cpp - rfactor.cpp round.cpp saturating_casts.cpp scatter.cpp set_custom_trace.cpp shadowed_bound.cpp shared_self_references.cpp + shift_by_unsigned_negated.cpp shifted_image.cpp side_effects.cpp simd_op_check.cpp @@ -310,12 +285,10 @@ tests(GROUPS correctness stmt_to_html.cpp storage_folding.cpp store_in.cpp - stream_compaction.cpp strict_float.cpp strict_float_bounds.cpp strided_load.cpp target.cpp - thread_safety.cpp tiled_matmul.cpp tracing.cpp tracing_bounds.cpp @@ -323,19 +296,17 @@ tests(GROUPS correctness tracing_stack.cpp transitive_bounds.cpp trim_no_ops.cpp - truncated_pyramid.cpp tuple_partial_update.cpp tuple_reduction.cpp tuple_select.cpp tuple_undef.cpp tuple_update_ops.cpp - tuple_vector_reduce.cpp two_vector_args.cpp + typed_func.cpp undef.cpp uninitialized_read.cpp unique_func_image.cpp unroll_dynamic_loop.cpp - unroll_huge_mux.cpp unrolled_reduction.cpp unsafe_dedup_lets.cpp unsafe_promises.cpp @@ -360,6 +331,49 @@ tests(GROUPS correctness widening_reduction.cpp ) +tests(GROUPS correctness multithreaded + SOURCES + assertion_failure_in_parallel_for.cpp + async.cpp + async_copy_chain.cpp + atomic_tuples.cpp + atomics.cpp + compute_outermost.cpp + compute_with.cpp + convolution.cpp + extern_consumer_tiled.cpp + extern_stage.cpp + func_clone.cpp + func_wrapper.cpp + image_wrapper.cpp + interpreter.cpp + legal_race_condition.cpp + lots_of_dimensions.cpp + memoize.cpp + multi_pass_reduction.cpp + multiple_outputs_extern.cpp + multiple_scatter.cpp + named_updates.cpp + nested_shiftinwards.cpp + oddly_sized_output.cpp + parallel.cpp + parallel_alloc.cpp + parallel_fork.cpp + parallel_nested.cpp + parallel_nested_1.cpp + parallel_reductions.cpp + parallel_rvar.cpp + parallel_scatter.cpp + random.cpp + reorder_rvars.cpp + rfactor.cpp + stream_compaction.cpp + thread_safety.cpp + truncated_pyramid.cpp + tuple_vector_reduce.cpp + unroll_huge_mux.cpp + ) + # Make sure the test that needs image_io has it target_link_libraries(correctness_image_io PRIVATE Halide::ImageIO) @@ -367,6 +381,9 @@ target_link_libraries(correctness_image_io PRIVATE Halide::ImageIO) set_target_properties(correctness_async correctness_atomics correctness_c_function + correctness_callable + correctness_callable_generator + correctness_callable_typed correctness_compute_at_split_rvar correctness_concat correctness_custom_lowering_pass @@ -403,4 +420,3 @@ set_target_properties(correctness_async correctness_sliding_window correctness_storage_folding PROPERTIES ENABLE_EXPORTS TRUE) - diff --git a/test/correctness/async.cpp b/test/correctness/async.cpp index b82c01e32800..97278a212bba 100644 --- a/test/correctness/async.cpp +++ b/test/correctness/async.cpp @@ -2,13 +2,7 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - -extern "C" DLLEXPORT int expensive(int x) { +extern "C" HALIDE_EXPORT_SYMBOL int expensive(int x) { float f = 3.0f; for (int i = 0; i < (1 << 10); i++) { f = sqrtf(sinf(cosf(f))); diff --git a/test/correctness/atomic_tuples.cpp b/test/correctness/atomic_tuples.cpp index dc55e3920930..63a31a798157 100644 --- a/test/correctness/atomic_tuples.cpp +++ b/test/correctness/atomic_tuples.cpp @@ -213,7 +213,7 @@ int main(int argc, char **argv) { if (out(x, y) != correct) { printf("out(%d, %d) = %d instead of %d\n", x, y, out(x, y), correct); - //return -1; + // return -1; } } } diff --git a/test/correctness/atomics.cpp b/test/correctness/atomics.cpp index 9c1547d1be78..f6f9d459a701 100644 --- a/test/correctness/atomics.cpp +++ b/test/correctness/atomics.cpp @@ -4,12 +4,6 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - enum class Backend { CPU, CPUVectorize, @@ -1008,7 +1002,7 @@ void test_all(const Backend &backend) { } } -extern "C" DLLEXPORT int extern_func(int x) { +extern "C" HALIDE_EXPORT_SYMBOL int extern_func(int x) { return x + 1; } HalideExtern_1(int, extern_func, int); @@ -1060,7 +1054,7 @@ void test_extern_func(const Backend &backend) { } } -extern "C" DLLEXPORT int expensive(int x) { +extern "C" HALIDE_EXPORT_SYMBOL int expensive(int x) { float f = 3.0f; for (int i = 0; i < (1 << 10); i++) { f = sqrtf(sinf(cosf(f))); diff --git a/test/correctness/boundary_conditions.cpp b/test/correctness/boundary_conditions.cpp index f80698a22db7..697678cf28e8 100644 --- a/test/correctness/boundary_conditions.cpp +++ b/test/correctness/boundary_conditions.cpp @@ -1,7 +1,7 @@ #include "Halide.h" -#include -#include +#include "test_sharding.h" +#include #include using namespace Halide; @@ -31,7 +31,7 @@ void schedule_test(Func f, int vector_width, const Target &t) { f.gpu_tile(x, y, xo, yo, xi, yi, 2, 2); } else if (t.has_feature(Target::HVX)) { // TODO: Non-native vector widths hang the compiler here. - //f.hexagon(); + // f.hexagon(); } } @@ -180,9 +180,11 @@ bool check_mirror_interior(const Buffer &input, Func f, return success; } -bool test_all(int vector_width, Target t) { - bool success = true; +struct Task { + std::function fn; +}; +void add_all(int vector_width, Target t, std::vector &tasks) { const int W = 32; const int H = 32; Buffer input(W, H); @@ -201,34 +203,34 @@ bool test_all(int vector_width, Target t) { const int32_t test_extent = 100; // Func input. - success &= check_repeat_edge( - input, - repeat_edge(input_f, {{0, W}, {0, H}}), - test_min, test_extent, test_min, test_extent, - vector_width, t); + tasks.push_back({[=]() { return check_repeat_edge( + input, + repeat_edge(input_f, {{0, W}, {0, H}}), + test_min, test_extent, test_min, test_extent, + vector_width, t); }}); // Image input. - success &= check_repeat_edge( - input, - repeat_edge(input, {{0, W}, {0, H}}), - test_min, test_extent, test_min, test_extent, - vector_width, t); + tasks.push_back({[=]() { return check_repeat_edge( + input, + repeat_edge(input, {{0, W}, {0, H}}), + test_min, test_extent, test_min, test_extent, + vector_width, t); }}); // Undefined bounds. - success &= check_repeat_edge( - input, - repeat_edge(input, {{Expr(), Expr()}, {0, H}}), - 0, W, test_min, test_extent, - vector_width, t); - success &= check_repeat_edge( - input, - repeat_edge(input, {{0, W}, {Expr(), Expr()}}), - test_min, test_extent, 0, H, - vector_width, t); + tasks.push_back({[=]() { return check_repeat_edge( + input, + repeat_edge(input, {{Expr(), Expr()}, {0, H}}), + 0, W, test_min, test_extent, + vector_width, t); }}); + tasks.push_back({[=]() { return check_repeat_edge( + input, + repeat_edge(input, {{0, W}, {Expr(), Expr()}}), + test_min, test_extent, 0, H, + vector_width, t); }}); // Implicitly determined bounds. - success &= check_repeat_edge( - input, - repeat_edge(input), - test_min, test_extent, test_min, test_extent, - vector_width, t); + tasks.push_back({[=]() { return check_repeat_edge( + input, + repeat_edge(input), + test_min, test_extent, test_min, test_extent, + vector_width, t); }}); } // constant_exterior: @@ -239,34 +241,34 @@ bool test_all(int vector_width, Target t) { const uint8_t exterior = 42; // Func input. - success &= check_constant_exterior( - input, exterior, - constant_exterior(input_f, exterior, {{0, W}, {0, H}}), - test_min, test_extent, test_min, test_extent, - vector_width, t); + tasks.push_back({[=]() { return check_constant_exterior( + input, exterior, + constant_exterior(input_f, exterior, {{0, W}, {0, H}}), + test_min, test_extent, test_min, test_extent, + vector_width, t); }}); // Image input. - success &= check_constant_exterior( - input, exterior, - constant_exterior(input, exterior, {{0, W}, {0, H}}), - test_min, test_extent, test_min, test_extent, - vector_width, t); + tasks.push_back({[=]() { return check_constant_exterior( + input, exterior, + constant_exterior(input, exterior, {{0, W}, {0, H}}), + test_min, test_extent, test_min, test_extent, + vector_width, t); }}); // Undefined bounds. - success &= check_constant_exterior( - input, exterior, - constant_exterior(input, exterior, {{Expr(), Expr()}, {0, H}}), - 0, W, test_min, test_extent, - vector_width, t); - success &= check_constant_exterior( - input, exterior, - constant_exterior(input, exterior, {{0, W}, {Expr(), Expr()}}), - test_min, test_extent, 0, H, - vector_width, t); + tasks.push_back({[=]() { return check_constant_exterior( + input, exterior, + constant_exterior(input, exterior, {{Expr(), Expr()}, {0, H}}), + 0, W, test_min, test_extent, + vector_width, t); }}); + tasks.push_back({[=]() { return check_constant_exterior( + input, exterior, + constant_exterior(input, exterior, {{0, W}, {Expr(), Expr()}}), + test_min, test_extent, 0, H, + vector_width, t); }}); // Implicitly determined bounds. - success &= check_constant_exterior( - input, exterior, - constant_exterior(input, exterior), - test_min, test_extent, test_min, test_extent, - vector_width, t); + tasks.push_back({[=]() { return check_constant_exterior( + input, exterior, + constant_exterior(input, exterior), + test_min, test_extent, test_min, test_extent, + vector_width, t); }}); } // repeat_image: @@ -275,34 +277,34 @@ bool test_all(int vector_width, Target t) { const int32_t test_extent = 100; // Func input. - success &= check_repeat_image( - input, - repeat_image(input_f, {{0, W}, {0, H}}), - test_min, test_extent, test_min, test_extent, - vector_width, t); + tasks.push_back({[=]() { return check_repeat_image( + input, + repeat_image(input_f, {{0, W}, {0, H}}), + test_min, test_extent, test_min, test_extent, + vector_width, t); }}); // Image input. - success &= check_repeat_image( - input, - repeat_image(input, {{0, W}, {0, H}}), - test_min, test_extent, test_min, test_extent, - vector_width, t); + tasks.push_back({[=]() { return check_repeat_image( + input, + repeat_image(input, {{0, W}, {0, H}}), + test_min, test_extent, test_min, test_extent, + vector_width, t); }}); // Undefined bounds. - success &= check_repeat_image( - input, - repeat_image(input, {{Expr(), Expr()}, {0, H}}), - 0, W, test_min, test_extent, - vector_width, t); - success &= check_repeat_image( - input, - repeat_image(input, {{0, W}, {Expr(), Expr()}}), - test_min, test_extent, 0, H, - vector_width, t); + tasks.push_back({[=]() { return check_repeat_image( + input, + repeat_image(input, {{Expr(), Expr()}, {0, H}}), + 0, W, test_min, test_extent, + vector_width, t); }}); + tasks.push_back({[=]() { return check_repeat_image( + input, + repeat_image(input, {{0, W}, {Expr(), Expr()}}), + test_min, test_extent, 0, H, + vector_width, t); }}); // Implicitly determined bounds. - success &= check_repeat_image( - input, - repeat_image(input), - test_min, test_extent, test_min, test_extent, - vector_width, t); + tasks.push_back({[=]() { return check_repeat_image( + input, + repeat_image(input), + test_min, test_extent, test_min, test_extent, + vector_width, t); }}); } // mirror_image: @@ -311,34 +313,34 @@ bool test_all(int vector_width, Target t) { const int32_t test_extent = 100; // Func input. - success &= check_mirror_image( - input, - mirror_image(input_f, {{0, W}, {0, H}}), - test_min, test_extent, test_min, test_extent, - vector_width, t); + tasks.push_back({[=]() { return check_mirror_image( + input, + mirror_image(input_f, {{0, W}, {0, H}}), + test_min, test_extent, test_min, test_extent, + vector_width, t); }}); // Image input. - success &= check_mirror_image( - input, - mirror_image(input, {{0, W}, {0, H}}), - test_min, test_extent, test_min, test_extent, - vector_width, t); + tasks.push_back({[=]() { return check_mirror_image( + input, + mirror_image(input, {{0, W}, {0, H}}), + test_min, test_extent, test_min, test_extent, + vector_width, t); }}); // Undefined bounds. - success &= check_mirror_image( - input, - mirror_image(input, {{Expr(), Expr()}, {0, H}}), - 0, W, test_min, test_extent, - vector_width, t); - success &= check_mirror_image( - input, - mirror_image(input, {{0, W}, {Expr(), Expr()}}), - test_min, test_extent, 0, H, - vector_width, t); + tasks.push_back({[=]() { return check_mirror_image( + input, + mirror_image(input, {{Expr(), Expr()}, {0, H}}), + 0, W, test_min, test_extent, + vector_width, t); }}); + tasks.push_back({[=]() { return check_mirror_image( + input, + mirror_image(input, {{0, W}, {Expr(), Expr()}}), + test_min, test_extent, 0, H, + vector_width, t); }}); // Implicitly determined bounds. - success &= check_mirror_image( - input, - mirror_image(input), - test_min, test_extent, test_min, test_extent, - vector_width, t); + tasks.push_back({[=]() { return check_mirror_image( + input, + mirror_image(input), + test_min, test_extent, test_min, test_extent, + vector_width, t); }}); } // mirror_interior: @@ -347,44 +349,40 @@ bool test_all(int vector_width, Target t) { const int32_t test_extent = 100; // Func input. - success &= check_mirror_interior( - input, - mirror_interior(input_f, {{0, W}, {0, H}}), - test_min, test_extent, test_min, test_extent, - vector_width, t); + tasks.push_back({[=]() { return check_mirror_interior( + input, + mirror_interior(input_f, {{0, W}, {0, H}}), + test_min, test_extent, test_min, test_extent, + vector_width, t); }}); // Image input. - success &= check_mirror_interior( - input, - mirror_interior(input, {{0, W}, {0, H}}), - test_min, test_extent, test_min, test_extent, - vector_width, t); + tasks.push_back({[=]() { return check_mirror_interior( + input, + mirror_interior(input, {{0, W}, {0, H}}), + test_min, test_extent, test_min, test_extent, + vector_width, t); }}); // Undefined bounds. - success &= check_mirror_interior( - input, - mirror_interior(input, {{Expr(), Expr()}, {0, H}}), - 0, W, test_min, test_extent, - vector_width, t); - success &= check_mirror_interior( - input, - mirror_interior(input, {{0, W}, {Expr(), Expr()}}), - test_min, test_extent, 0, H, - vector_width, t); + tasks.push_back({[=]() { return check_mirror_interior( + input, + mirror_interior(input, {{Expr(), Expr()}, {0, H}}), + 0, W, test_min, test_extent, + vector_width, t); }}); + tasks.push_back({[=]() { return check_mirror_interior( + input, + mirror_interior(input, {{0, W}, {Expr(), Expr()}}), + test_min, test_extent, 0, H, + vector_width, t); }}); // Implicitly determined bounds. - success &= check_mirror_interior( - input, - mirror_interior(input), - test_min, test_extent, test_min, test_extent, - vector_width, t); + tasks.push_back({[=]() { return check_mirror_interior( + input, + mirror_interior(input), + test_min, test_extent, test_min, test_extent, + vector_width, t); }}); } - - return success; } int main(int argc, char **argv) { Target target = get_jit_target_from_environment(); - Halide::Internal::ThreadPool pool; - std::vector> futures; int vector_width_max = 32; if (target.has_feature(Target::Metal) || target.has_feature(Target::OpenGLCompute) || @@ -396,24 +394,20 @@ int main(int argc, char **argv) { // The wasm jit is very slow, so shorten this test here. vector_width_max = 8; } - for (int vector_width = 1; vector_width <= vector_width_max; vector_width *= 2) { - std::cout << "Testing vector_width: " << vector_width << "\n"; - if (target.has_feature(Target::OpenGLCompute)) { - // GL can't be used from multiple threads at once - test_all(vector_width, target); - } else { - futures.push_back(pool.async(test_all, vector_width, target)); - } - } - bool success = true; - for (auto &f : futures) { - success &= f.get(); + std::vector tasks; + for (int vector_width = 1; vector_width <= vector_width_max; vector_width *= 2) { + add_all(vector_width, target, tasks); } - if (!success) { - fprintf(stderr, "Failed!\n"); - return -1; + using Sharder = Halide::Internal::Test::Sharder; + Sharder sharder; + for (size_t t = 0; t < tasks.size(); t++) { + if (!sharder.should_run(t)) continue; + const auto &task = tasks.at(t); + if (!task.fn()) { + exit(-1); + } } printf("Success!\n"); diff --git a/test/correctness/bounds_inference_chunk.cpp b/test/correctness/bounds_inference_chunk.cpp index 74345b72b56b..63d4e8b12751 100644 --- a/test/correctness/bounds_inference_chunk.cpp +++ b/test/correctness/bounds_inference_chunk.cpp @@ -15,7 +15,7 @@ int main(int argc, char **argv) { h.compute_root(); g.compute_at(f, y); - //f.trace(); + // f.trace(); Buffer out = f.realize({32, 32}); diff --git a/test/correctness/c_function.cpp b/test/correctness/c_function.cpp index f83f2d66eff1..cadab3386f7f 100644 --- a/test/correctness/c_function.cpp +++ b/test/correctness/c_function.cpp @@ -7,27 +7,21 @@ using namespace Halide; // This is not supported by the C backend. // On windows, you need to use declspec to do the same. -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - int call_counter = 0; -extern "C" DLLEXPORT float my_func(int x, float y) { +extern "C" HALIDE_EXPORT_SYMBOL float my_func(int x, float y) { call_counter++; return x * y; } HalideExtern_2(float, my_func, int, float); int call_counter2 = 0; -extern "C" DLLEXPORT float my_func2(int x, float y) { +extern "C" HALIDE_EXPORT_SYMBOL float my_func2(int x, float y) { call_counter2++; return x * y; } int call_counter3 = 0; -extern "C" DLLEXPORT float my_func3(int x, float y) { +extern "C" HALIDE_EXPORT_SYMBOL float my_func3(int x, float y) { call_counter3++; return x * y; } diff --git a/test/correctness/callable.cpp b/test/correctness/callable.cpp new file mode 100644 index 000000000000..e3ab207aa025 --- /dev/null +++ b/test/correctness/callable.cpp @@ -0,0 +1,224 @@ +#include "Halide.h" +#include +#include + +using namespace Halide; + +namespace { + +void check(int r) { + assert(r == 0); +} + +bool custom_malloc_called = false; +bool custom_free_called = false; + +void *my_malloc(JITUserContext *user_context, size_t x) { + custom_malloc_called = true; + void *orig = malloc(x + 32); + void *ptr = (void *)((((size_t)orig + 32) >> 5) << 5); + ((void **)ptr)[-1] = orig; + return ptr; +} + +void my_free(JITUserContext *user_context, void *ptr) { + custom_free_called = true; + free(((void **)ptr)[-1]); +} + +void *mischievous_malloc(JITUserContext *user_context, size_t x) { + fprintf(stderr, "This should never get called\n"); + abort(); + return nullptr; +} + +int call_counter = 0; +extern "C" HALIDE_EXPORT_SYMBOL float my_extern_func(int x, float y) { + call_counter++; + return x * y; +} +HalideExtern_2(float, my_extern_func, int, float); + +} // namespace + +int main(int argc, char **argv) { + const Target t = get_jit_target_from_environment(); + + { + Param p_int(42); + Param p_float(1.0f); + ImageParam p_img(UInt(8), 2); + + Var x("x"), y("y"); + Func f("f"); + + f(x, y) = p_img(x, y) + cast(p_int / p_float); + + Buffer in1(10, 10); + Buffer in2(10, 10); + + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + in1(i, j) = i + j * 10; + in2(i, j) = i * 10 + j; + } + } + + Callable c = f.compile_to_callable({p_img, p_int, p_float}, t); + + { + Buffer out1(10, 10); + check(c(in1, 42, 1.0f, out1)); + + Buffer out2(10, 10); + check(c(in2, 22, 2.0f, out2)); + + Buffer out3(10, 10); + check(c(in1, 12, 1.0f, out3)); + + Buffer out4(10, 10); + check(c(in2, 16, 1.0f, out4)); + + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + assert(out1(i, j) == i + j * 10 + 42); + assert(out2(i, j) == i * 10 + j + 11); + assert(out3(i, j) == i + j * 10 + 12); + assert(out4(i, j) == i * 10 + j + 16); + } + } + } + + { + // Test bounds inference + Buffer in_bounds(nullptr, 1, 1); + Buffer out_bounds(nullptr, 20, 20); + + check(c(in_bounds, 42, 1.0f, out_bounds)); + + assert(in_bounds.defined()); + assert(in_bounds.dim(0).extent() == 20); + assert(in_bounds.dim(1).extent() == 20); + assert(in1.dim(0).extent() == 10); + assert(in1.dim(1).extent() == 10); + } + } + + // Override Halide's malloc and free (except under wasm), + // make sure that Callable freezes the values + if (t.arch != Target::WebAssembly) { + custom_malloc_called = false; + custom_free_called = false; + + Func f, g; + Var x; + + f(x) = x; + g(x) = f(x); + f.compute_root(); + + g.jit_handlers().custom_malloc = my_malloc; + g.jit_handlers().custom_free = my_free; + + Callable c = g.compile_to_callable({}); + + // Changing g's handlers shouldn't affect any existing Callables + g.jit_handlers().custom_malloc = mischievous_malloc; + + Buffer im(100000); + check(c(im)); + + assert(custom_malloc_called); + assert(custom_free_called); + } + + // Check that Param works with Callables + if (t.arch != Target::WebAssembly) { + Func f("f"), g("g"); + Var x("x"); + Param handle("handle"); + + f(x) = reinterpret(handle); + + g(x) = reinterpret(handle); + g.vectorize(x, 4); + + Callable cf = f.compile_to_callable({handle}); + Callable cg = g.compile_to_callable({handle}); + + int foo = 0; + + Buffer out1(4); + // Create a dummy JITUserContext here just to test that + // passing one explicitly works correctly. + JITUserContext empty; + check(cf(&empty, &foo, out1)); + + Buffer out2(4); + check(cg(&foo, out2)); + + uint64_t correct = (uint64_t)((uintptr_t)(&foo)); + + for (int x = 0; x < out1.width(); x++) { + if (out1(x) != correct) { + printf("out1(%d) = %llu instead of %llu\n", + x, + (long long unsigned)out1(x), + (long long unsigned)correct); + exit(-1); + } + if (out2(x) != correct) { + printf("out2(%d) = %llu instead of %llu\n", + x, + (long long unsigned)out2(x), + (long long unsigned)correct); + exit(-1); + } + } + } + + // Check that JITExterns works with Callables + if (t.arch != Target::WebAssembly) { + call_counter = 0; + + std::vector args; + args.push_back(user_context_value()); + + Var x, y; + Func monitor; + monitor(x, y) = my_extern_func(x, cast(y)); + + Func f; + f.define_extern("extern_func", args, Float(32), 2); + + Pipeline p(f); + p.set_jit_externs({{"extern_func", JITExtern{monitor}}}); + + Callable c = p.compile_to_callable({}); + + // Changing g's jit_externs shouldn't affect any existing Callables + p.set_jit_externs({}); + + Buffer imf(32, 32); + check(c(imf)); + + // Check the result was what we expected + for (int i = 0; i < 32; i++) { + for (int j = 0; j < 32; j++) { + float correct = (float)(i * j); + float delta = imf(i, j) - correct; + if (delta < -0.001 || delta > 0.001) { + printf("imf[%d, %d] = %f instead of %f\n", i, j, imf(i, j), correct); + exit(-1); + } + } + } + + if (call_counter != 32 * 32) { + printf("In pipeline_set_jit_externs_func, my_func was called %d times instead of %d\n", call_counter, 32 * 32); + exit(-1); + } + } + + printf("Success!\n"); +} diff --git a/test/correctness/callable_errors.cpp b/test/correctness/callable_errors.cpp new file mode 100644 index 000000000000..2c0abce5407f --- /dev/null +++ b/test/correctness/callable_errors.cpp @@ -0,0 +1,221 @@ +#include "Halide.h" +#include +#include + +using namespace Halide; + +namespace { + +std::string error_msg; +void my_error(JITUserContext *ucon, const char *msg) { + error_msg = msg; +} + +void expect_failure(int r, const char *expected_msg) { + if (r == 0) { + std::cerr << "Expected failure, got success\n"; + exit(1); + } + if (!strstr(error_msg.c_str(), expected_msg)) { + std::cerr << "Expected error containing (" << expected_msg << "), but got (" << error_msg << ")\n"; + exit(1); + } + std::cout << "Saw expected: (" << expected_msg << ")\n"; + error_msg = ""; +} + +void expect_success(int r) { + if (r != 0) { + std::cerr << "Expected success, got failure\n"; + exit(1); + } + if (!error_msg.empty()) { + std::cerr << "Expected NO ERROR, got (" << error_msg << ")\n"; + exit(1); + } + std::cout << "Saw expected: (NO ERROR)\n"; + error_msg = ""; +} + +void test_bad_untyped_calls() { + // Test custom error handler in the JITHandler + { + Param p_int("p_int"); + Param p_float("p_float"); + ImageParam p_img(UInt(8), 2, "p_img"); + + Var x("x"), y("y"); + Func f("fn1"); + + f(x, y) = p_img(x, y) + cast(p_int / p_float); + + f.jit_handlers().custom_error = my_error; + + Callable c = f.compile_to_callable({p_img, p_int, p_float}); + + Buffer in1(10, 10), result1(10, 10); + in1.fill(0); + + expect_success(c(in1, 2, 1.0f, result1)); + expect_failure(c((const halide_buffer_t *)nullptr, 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c((halide_buffer_t *)nullptr, 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c(Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c(Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c(Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c(Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c(Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c(Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c(Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c(Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c(42, 2, 1.0f, result1), "Argument 1 of 4 ('p_img') was expected to be a buffer of type 'uint8' and dimension 2"); + expect_failure(c(in1, 2.25, 1.0f, result1), "Argument 2 of 4 ('p_int') was expected to be a scalar of type 'int32' and dimension 0"); + expect_failure(c(in1, 2, 1, result1), "Argument 3 of 4 ('p_float') was expected to be a scalar of type 'float32' and dimension 0"); + expect_failure(c(in1, 2, 1.0f, (const halide_buffer_t *)nullptr), "Buffer argument fn1 is nullptr"); + expect_failure(c(in1, 2, 1.0f, (halide_buffer_t *)nullptr), "Buffer argument fn1 is nullptr"); + expect_failure(c(in1, 2, 1.0f, Buffer()), "Buffer argument fn1 is nullptr"); + expect_failure(c(in1, 2, 1.0f, Buffer()), "Buffer argument fn1 is nullptr"); + expect_failure(c(in1, 2, 1.0f, Buffer()), "Buffer argument fn1 is nullptr"); + expect_failure(c(in1, 2, 1.0f, Buffer()), "Buffer argument fn1 is nullptr"); + expect_failure(c(in1, 2, 1.0f, Buffer()), "Buffer argument fn1 is nullptr"); + expect_failure(c(in1, 2, 1.0f, Buffer()), "Buffer argument fn1 is nullptr"); + expect_failure(c(in1, 2, 1.0f, Buffer()), "Buffer argument fn1 is nullptr"); + expect_failure(c(in1, 2, 1.0f, Buffer()), "Buffer argument fn1 is nullptr"); + } + + // Test custom error handler in the JITUserContext + { + Param p_int("p_int"); + Param p_float("p_float"); + ImageParam p_img(UInt(8), 2, "p_img"); + + Var x("x"), y("y"); + Func f("fn2"); + + f(x, y) = p_img(x, y) + cast(p_int / p_float); + + Callable c = f.compile_to_callable({p_img, p_int, p_float}); + + Buffer in1(10, 10), result1(10, 10); + in1.fill(0); + + JITUserContext context; + context.handlers.custom_error = my_error; + + expect_success(c(&context, in1, 2, 1.0f, result1)); + expect_failure(c(&context, (const halide_buffer_t *)nullptr, 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c(&context, (halide_buffer_t *)nullptr, 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c(&context, Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c(&context, Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c(&context, Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c(&context, Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c(&context, Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c(&context, Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c(&context, Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c(&context, Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c(&context, 42, 2, 1.0f, result1), "Argument 1 of 4 ('p_img') was expected to be a buffer of type 'uint8' and dimension 2"); + expect_failure(c(&context, in1, 2.25, 1.0f, result1), "Argument 2 of 4 ('p_int') was expected to be a scalar of type 'int32' and dimension 0"); + expect_failure(c(&context, in1, 2, 1, result1), "Argument 3 of 4 ('p_float') was expected to be a scalar of type 'float32' and dimension 0"); + expect_failure(c(&context, in1, 2, 1.0f, (const halide_buffer_t *)nullptr), "Buffer argument fn2 is nullptr"); + expect_failure(c(&context, in1, 2, 1.0f, (halide_buffer_t *)nullptr), "Buffer argument fn2 is nullptr"); + expect_failure(c(&context, in1, 2, 1.0f, Buffer()), "Buffer argument fn2 is nullptr"); + expect_failure(c(&context, in1, 2, 1.0f, Buffer()), "Buffer argument fn2 is nullptr"); + expect_failure(c(&context, in1, 2, 1.0f, Buffer()), "Buffer argument fn2 is nullptr"); + expect_failure(c(&context, in1, 2, 1.0f, Buffer()), "Buffer argument fn2 is nullptr"); + expect_failure(c(&context, in1, 2, 1.0f, Buffer()), "Buffer argument fn2 is nullptr"); + expect_failure(c(&context, in1, 2, 1.0f, Buffer()), "Buffer argument fn2 is nullptr"); + expect_failure(c(&context, in1, 2, 1.0f, Buffer()), "Buffer argument fn2 is nullptr"); + expect_failure(c(&context, in1, 2, 1.0f, Buffer()), "Buffer argument fn2 is nullptr"); + } +} + +void test_bad_typed_calls() { + // Test custom error handler in the JITHandler + { + Param p_int("p_int"); + Param p_float("p_float"); + ImageParam p_img(UInt(8), 2, "p_img"); + + Var x("x"), y("y"); + Func f("fn3"); + + f(x, y) = p_img(x, y) + cast(p_int / p_float); + + f.jit_handlers().custom_error = my_error; + + Callable c = f.compile_to_callable({p_img, p_int, p_float}); + + Buffer in1(10, 10), result1(10, 10); + in1.fill(0); + + auto c_typed = c.make_std_function, int32_t, float, Buffer>(); + expect_success(c_typed(in1, 2, 1.0f, result1)); + + // make_std_function succeeds, but calls to it fail at runtime + expect_failure(c_typed(Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c_typed(Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c_typed(Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c_typed(Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c_typed(in1, 2, 1.0f, Buffer()), "Buffer argument fn3 is nullptr"); + expect_failure(c_typed(in1, 2, 1.0f, Buffer()), "Buffer argument fn3 is nullptr"); + expect_failure(c_typed(in1, 2, 1.0f, Buffer()), "Buffer argument fn3 is nullptr"); + expect_failure(c_typed(in1, 2, 1.0f, Buffer()), "Buffer argument fn3 is nullptr"); + + // Calls to make_std_function fail + c.make_std_function>(); + expect_failure(-1, "Argument 1 of 4 ('p_img') was expected to be a buffer of type 'uint8' and dimension 2"); + + c.make_std_function, bool, float, Buffer>(); + expect_failure(-1, "Argument 2 of 4 ('p_int') was expected to be a scalar of type 'int32' and dimension 0"); + + c.make_std_function, int32_t, bool, Buffer>(); + expect_failure(-1, "Argument 3 of 4 ('p_float') was expected to be a scalar of type 'float32' and dimension 0"); + + c.make_std_function, int32_t, float, bool>(); + expect_failure(-1, "Argument 4 of 4 ('fn3') was expected to be a buffer of type 'uint8' and dimension 2"); + } + + // Test custom error handler in the JITUserContext + { + Param p_int("p_int"); + Param p_float("p_float"); + ImageParam p_img(UInt(8), 2, "p_img"); + + Var x("x"), y("y"); + Func f("fn4"); + + f(x, y) = p_img(x, y) + cast(p_int / p_float); + + Callable c = f.compile_to_callable({p_img, p_int, p_float}); + + Buffer in1(10, 10), result1(10, 10); + in1.fill(0); + + JITUserContext context; + context.handlers.custom_error = my_error; + + auto c_typed = c.make_std_function, int32_t, float, Buffer>(); + expect_success(c_typed(&context, in1, 2, 1.0f, result1)); + + // make_std_function succeeds, but calls to it fail at runtime + expect_failure(c_typed(&context, Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c_typed(&context, Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c_typed(&context, Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c_typed(&context, Buffer(), 2, 1.0f, result1), "Buffer argument p_img is nullptr"); + expect_failure(c_typed(&context, in1, 2, 1.0f, Buffer()), "Buffer argument fn4 is nullptr"); + expect_failure(c_typed(&context, in1, 2, 1.0f, Buffer()), "Buffer argument fn4 is nullptr"); + expect_failure(c_typed(&context, in1, 2, 1.0f, Buffer()), "Buffer argument fn4 is nullptr"); + expect_failure(c_typed(&context, in1, 2, 1.0f, Buffer()), "Buffer argument fn4 is nullptr"); + + // Note that since make_std_function doesn't take a JITUserContext, we aren't able to hook the error handler + // here, so all of these will just assert-fail and kill the test. We'll just skip the tests here, as it's + // exercised elsewhere enough. + } +} +} // namespace + +int main(int argc, char **argv) { + test_bad_untyped_calls(); + test_bad_typed_calls(); + + printf("Success!\n"); +} diff --git a/test/correctness/callable_generator.cpp b/test/correctness/callable_generator.cpp new file mode 100644 index 000000000000..2680d7778cd4 --- /dev/null +++ b/test/correctness/callable_generator.cpp @@ -0,0 +1,247 @@ +#include "Halide.h" +#include +#include + +using namespace Halide; + +namespace { + +void check(int r) { + assert(r == 0); +} + +bool custom_malloc_called = false; +bool custom_free_called = false; + +void *my_malloc(JITUserContext *user_context, size_t x) { + custom_malloc_called = true; + void *orig = malloc(x + 32); + void *ptr = (void *)((((size_t)orig + 32) >> 5) << 5); + ((void **)ptr)[-1] = orig; + return ptr; +} + +void my_free(JITUserContext *user_context, void *ptr) { + custom_free_called = true; + free(((void **)ptr)[-1]); +} + +int call_counter = 0; +extern "C" HALIDE_EXPORT_SYMBOL float my_extern_func(int x, float y) { + call_counter++; + return x * y; +} +HalideExtern_2(float, my_extern_func, int, float); + +} // namespace + +int main(int argc, char **argv) { + const Target t = get_jit_target_from_environment(); + const GeneratorContext context(t); + + { + class TestGen1 : public Generator { + public: + Input> img_{"img"}; + Input int_{"int"}; + Input float_{"float"}; + + Output> out_{"out"}; + + void generate() { + Var x("x"), y("y"); + out_(x, y) = img_(x, y) + cast(int_ / float_); + } + }; + + Buffer in1(10, 10); + Buffer in2(10, 10); + + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + in1(i, j) = i + j * 10; + in2(i, j) = i * 10 + j; + } + } + + auto gen = TestGen1::create(context); + Callable c = gen->compile_to_callable(); + + Buffer out1(10, 10); + check(c(in1, 42, 1.0f, out1)); + + Buffer out2(10, 10); + check(c(in2, 22, 2.0f, out2)); + + Buffer out3(10, 10); + check(c(in1, 12, 1.0f, out3)); + + Buffer out4(10, 10); + check(c(in2, 16, 1.0f, out4)); + + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + assert(out1(i, j) == i + j * 10 + 42); + assert(out2(i, j) == i * 10 + j + 11); + assert(out3(i, j) == i + j * 10 + 12); + assert(out4(i, j) == i * 10 + j + 16); + } + } + + // Test bounds inference + Buffer in_bounds(nullptr, 1, 1); + Buffer out_bounds(nullptr, 20, 20); + + check(c(in_bounds, 42, 1.0f, out_bounds)); + + assert(in_bounds.defined()); + assert(in_bounds.dim(0).extent() == 20); + assert(in_bounds.dim(1).extent() == 20); + assert(in1.dim(0).extent() == 10); + assert(in1.dim(1).extent() == 10); + } + + // Override Halide's malloc and free (except under wasm), + // make sure that Callable freezes the values + if (t.arch != Target::WebAssembly) { + custom_malloc_called = false; + custom_free_called = false; + + class TestGen2 : public Generator { + public: + Output> out_{"out"}; + + void generate() { + Var x("x"); + + Func f; + f(x) = x; + + out_(x) = f(x); + + f.compute_root(); + } + }; + + JITHandlers my_jit_handlers; + my_jit_handlers.custom_malloc = my_malloc; + my_jit_handlers.custom_free = my_free; + + auto gen = TestGen2::create(context); + Callable c = gen->compile_to_callable(&my_jit_handlers); + + Buffer im(100000); + check(c(im)); + + assert(custom_malloc_called); + assert(custom_free_called); + } + + // Check that Param works with Callables + if (t.arch != Target::WebAssembly) { + class TestGen3 : public Generator { + public: + GeneratorParam vectorize_{"vectorize", false}; + + Input handle_{"handle"}; + Output> out_{"out"}; + + void generate() { + Var x("x"); + + out_(x) = reinterpret(handle_); + if (vectorize_) { + out_.vectorize(x, 4); + } + } + }; + + auto gen_1 = TestGen3::create(context); + gen_1->vectorize_.set(false); + + auto gen_2 = TestGen3::create(context); + gen_2->vectorize_.set(true); + + Callable c1 = gen_1->compile_to_callable(); + Callable c2 = gen_2->compile_to_callable(); + + int foo = 0; + + Buffer out1(4); + // Create a dummy JITUserContext here just to test that + // passing one explicitly works correctly. + JITUserContext empty; + check(c1(&empty, &foo, out1)); + + Buffer out2(4); + check(c2(&foo, out2)); + + uint64_t correct = (uint64_t)((uintptr_t)(&foo)); + + for (int x = 0; x < out1.width(); x++) { + if (out1(x) != correct) { + printf("out1(%d) = %llu instead of %llu\n", + x, + (long long unsigned)out1(x), + (long long unsigned)correct); + exit(-1); + } + if (out2(x) != correct) { + printf("out2(%d) = %llu instead of %llu\n", + x, + (long long unsigned)out2(x), + (long long unsigned)correct); + exit(-1); + } + } + } + + // Check that JITExterns works with Callables + if (t.arch != Target::WebAssembly) { + call_counter = 0; + + class TestGen4 : public Generator { + public: + Output> out_{"out"}; + + void generate() { + Func f; + f.define_extern("extern_func", {user_context_value()}, Float(32), 2); + + Var x("x"), y("y"); + out_(x, y) = f(x, y); + } + }; + + Var x, y; + Func monitor; + monitor(x, y) = my_extern_func(x, cast(y)); + const std::map my_jit_externs = { + {"extern_func", JITExtern{monitor}}}; + + auto gen = TestGen4::create(context); + Callable c = gen->compile_to_callable(nullptr, &my_jit_externs); + + Buffer imf(32, 32); + check(c(imf)); + + // Check the result was what we expected + for (int i = 0; i < 32; i++) { + for (int j = 0; j < 32; j++) { + float correct = (float)(i * j); + float delta = imf(i, j) - correct; + if (delta < -0.001 || delta > 0.001) { + printf("imf[%d, %d] = %f instead of %f\n", i, j, imf(i, j), correct); + exit(-1); + } + } + } + + if (call_counter != 32 * 32) { + printf("In pipeline_set_jit_externs_func, my_func was called %d times instead of %d\n", call_counter, 32 * 32); + exit(-1); + } + } + + printf("Success!\n"); +} diff --git a/test/correctness/callable_typed.cpp b/test/correctness/callable_typed.cpp new file mode 100644 index 000000000000..d20fdbbd55c5 --- /dev/null +++ b/test/correctness/callable_typed.cpp @@ -0,0 +1,233 @@ +#include "Halide.h" +#include +#include + +using namespace Halide; + +namespace { + +void check(int r) { + assert(r == 0); +} + +bool custom_malloc_called = false; +bool custom_free_called = false; + +void *my_malloc(JITUserContext *user_context, size_t x) { + custom_malloc_called = true; + void *orig = malloc(x + 32); + void *ptr = (void *)((((size_t)orig + 32) >> 5) << 5); + ((void **)ptr)[-1] = orig; + return ptr; +} + +void my_free(JITUserContext *user_context, void *ptr) { + custom_free_called = true; + free(((void **)ptr)[-1]); +} + +void *mischievous_malloc(JITUserContext *user_context, size_t x) { + fprintf(stderr, "This should never get called\n"); + abort(); + return nullptr; +} + +int call_counter = 0; +extern "C" HALIDE_EXPORT_SYMBOL float my_extern_func(int x, float y) { + call_counter++; + return x * y; +} +HalideExtern_2(float, my_extern_func, int, float); + +} // namespace + +int main(int argc, char **argv) { + const Target t = get_jit_target_from_environment(); + + { + Param p_int(42); + Param p_float(1.0f); + ImageParam p_img(UInt(8), 2); + + Var x("x"), y("y"); + Func f("f"); + + f(x, y) = p_img(x, y) + cast(p_int / p_float); + + Buffer in1(10, 10); + Buffer in2(10, 10); + + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + in1(i, j) = i + j * 10; + in2(i, j) = i * 10 + j; + } + } + + // Note that we can't reliably infer the std::function<> signature in all + // cases, since some of the arguments may not be statically typed (e.g. Param), + // but `make_std_function` will fail at runtime if the template arguments + // don't match what is required. + auto c = f.compile_to_callable({p_img, p_int, p_float}, t) + .make_std_function, int, float, Buffer>(); + + { + Buffer out1(10, 10); + check(c(in1, 42, 1.0f, out1)); + + Buffer out2(10, 10); + check(c(in2, 22, 2.0f, out2)); + + Buffer out3(10, 10); + check(c(in1, 12, 1.0f, out3)); + + Buffer out4(10, 10); + check(c(in2, 16, 1.0f, out4)); + + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + assert(out1(i, j) == i + j * 10 + 42); + assert(out2(i, j) == i * 10 + j + 11); + assert(out3(i, j) == i + j * 10 + 12); + assert(out4(i, j) == i * 10 + j + 16); + } + } + } + + { + // Test bounds inference + Buffer in_bounds(nullptr, 1, 1); + Buffer out_bounds(nullptr, 20, 20); + + check(c(in_bounds, 42, 1.0f, out_bounds)); + + assert(in_bounds.defined()); + assert(in_bounds.dim(0).extent() == 20); + assert(in_bounds.dim(1).extent() == 20); + assert(in1.dim(0).extent() == 10); + assert(in1.dim(1).extent() == 10); + } + } + + // Override Halide's malloc and free (except under wasm), + // make sure that Callable freezes the values + if (t.arch != Target::WebAssembly) { + custom_malloc_called = false; + custom_free_called = false; + + Func f, g; + Var x; + + f(x) = x; + g(x) = f(x); + f.compute_root(); + + g.jit_handlers().custom_malloc = my_malloc; + g.jit_handlers().custom_free = my_free; + + auto c = g.compile_to_callable({}) + .make_std_function>(); + + // Changing g's handlers shouldn't affect any existing Callables + g.jit_handlers().custom_malloc = mischievous_malloc; + + Buffer im(100000); + check(c(im)); + + assert(custom_malloc_called); + assert(custom_free_called); + } + + // Check that Param works with Callables + if (t.arch != Target::WebAssembly) { + Func f("f"), g("g"); + Var x("x"); + Param handle("handle"); + + f(x) = reinterpret(handle); + + g(x) = reinterpret(handle); + g.vectorize(x, 4); + + // Create/use a dummy JITUserContext here just to test that + // passing one explicitly works correctly. + auto cf = f.compile_to_callable({handle}) + .make_std_function>(); + auto cg = g.compile_to_callable({handle}) + .make_std_function>(); + + int foo = 0; + + Buffer out1(4); + JITUserContext empty; + check(cf(&empty, &foo, out1)); + + Buffer out2(4); + check(cg(&foo, out2)); + + uint64_t correct = (uint64_t)((uintptr_t)(&foo)); + + for (int x = 0; x < out1.width(); x++) { + if (out1(x) != correct) { + printf("out1(%d) = %llu instead of %llu\n", + x, + (long long unsigned)out1(x), + (long long unsigned)correct); + exit(-1); + } + if (out2(x) != correct) { + printf("out2(%d) = %llu instead of %llu\n", + x, + (long long unsigned)out2(x), + (long long unsigned)correct); + exit(-1); + } + } + } + + // Check that JITExterns works with Callables + if (t.arch != Target::WebAssembly) { + call_counter = 0; + + std::vector args; + args.push_back(user_context_value()); + + Var x, y; + Func monitor; + monitor(x, y) = my_extern_func(x, cast(y)); + + Func f; + f.define_extern("extern_func", args, Float(32), 2); + + Pipeline p(f); + p.set_jit_externs({{"extern_func", JITExtern{monitor}}}); + + auto c = p.compile_to_callable({}) + .make_std_function>(); + + // Changing g's jit_externs shouldn't affect any existing Callables + p.set_jit_externs({}); + + Buffer imf(32, 32); + check(c(imf)); + + // Check the result was what we expected + for (int i = 0; i < 32; i++) { + for (int j = 0; j < 32; j++) { + float correct = (float)(i * j); + float delta = imf(i, j) - correct; + if (delta < -0.001 || delta > 0.001) { + printf("imf[%d, %d] = %f instead of %f\n", i, j, imf(i, j), correct); + exit(-1); + } + } + } + + if (call_counter != 32 * 32) { + printf("In pipeline_set_jit_externs_func, my_func was called %d times instead of %d\n", call_counter, 32 * 32); + exit(-1); + } + } + + printf("Success!\n"); +} diff --git a/test/correctness/compile_to_multitarget.cpp b/test/correctness/compile_to_multitarget.cpp index 7bbfbf082a09..4a582c463a47 100644 --- a/test/correctness/compile_to_multitarget.cpp +++ b/test/correctness/compile_to_multitarget.cpp @@ -5,12 +5,19 @@ using namespace Halide; -std::string get_fname(const std::string &base) { +// Given a path like /path/to/some/file.ext, return file.ext +// If the path contains no separators (/ or \), just return it as-is +std::string leaf_name(const std::string &path) { + size_t sep = std::min(path.rfind('/'), path.rfind('\\')); + return path.substr(sep == std::string::npos ? 0 : sep + 1); +} + +std::string get_output_path_prefix(const std::string &base) { return Internal::get_test_tmp_dir() + "halide_test_correctness_compile_to_multitarget_" + base; } void test_compile_to_static_library(Func j) { - std::string fname = get_fname("c1"); + std::string filename_prefix = get_output_path_prefix("c1"); const char *a = get_host_target().os == Target::Windows ? ".lib" : ".a"; std::vector targets = { @@ -19,14 +26,14 @@ void test_compile_to_static_library(Func j) { }; std::vector files; - files.push_back(fname + ".h"); - files.push_back(fname + a); + files.push_back(filename_prefix + ".h"); + files.push_back(filename_prefix + a); for (auto f : files) { Internal::ensure_no_file_exists(f); } - j.compile_to_multitarget_static_library(fname, j.infer_arguments(), targets); + j.compile_to_multitarget_static_library(filename_prefix, j.infer_arguments(), targets); for (auto f : files) { Internal::assert_file_exists(f); @@ -37,7 +44,7 @@ void test_compile_to_static_library(Func j) { } void test_compile_to_object_files(Func j) { - std::string fname = get_fname("c2"); + std::string filename_prefix = get_output_path_prefix("c2"); const char *o = get_host_target().os == Target::Windows ? ".obj" : ".o"; std::vector target_strings = { @@ -51,18 +58,18 @@ void test_compile_to_object_files(Func j) { } std::vector files; - files.push_back(fname + ".h"); - files.push_back(fname + "_runtime" + o); - files.push_back(fname + "_wrapper" + o); + files.push_back(filename_prefix + ".h"); + files.push_back(filename_prefix + "_runtime" + o); + files.push_back(filename_prefix + "_wrapper" + o); for (auto s : target_strings) { - files.push_back(fname + "-" + s + o); + files.push_back(filename_prefix + "-" + s + o); } for (auto f : files) { Internal::ensure_no_file_exists(f); } - j.compile_to_multitarget_object_files(fname, j.infer_arguments(), targets, target_strings); + j.compile_to_multitarget_object_files(filename_prefix, j.infer_arguments(), targets, target_strings); for (auto f : files) { Internal::assert_file_exists(f); @@ -70,7 +77,7 @@ void test_compile_to_object_files(Func j) { } void test_compile_to_object_files_no_runtime(Func j) { - std::string fname = get_fname("c3"); + std::string filename_prefix = get_output_path_prefix("c3"); const char *o = get_host_target().os == Target::Windows ? ".obj" : ".o"; std::vector target_strings = { @@ -84,17 +91,17 @@ void test_compile_to_object_files_no_runtime(Func j) { } std::vector files; - files.push_back(fname + ".h"); - files.push_back(fname + "_wrapper" + o); + files.push_back(filename_prefix + ".h"); + files.push_back(filename_prefix + "_wrapper" + o); for (auto s : target_strings) { - files.push_back(fname + "-" + s + o); + files.push_back(filename_prefix + "-" + s + o); } for (auto f : files) { Internal::ensure_no_file_exists(f); } - j.compile_to_multitarget_object_files(fname, j.infer_arguments(), targets, target_strings); + j.compile_to_multitarget_object_files(filename_prefix, j.infer_arguments(), targets, target_strings); for (auto f : files) { Internal::assert_file_exists(f); @@ -102,7 +109,7 @@ void test_compile_to_object_files_no_runtime(Func j) { } void test_compile_to_object_files_single_target(Func j) { - std::string fname = get_fname("c4"); + std::string filename_prefix = get_output_path_prefix("c4"); const char *o = get_host_target().os == Target::Windows ? ".obj" : ".o"; std::vector target_strings = { @@ -115,14 +122,14 @@ void test_compile_to_object_files_single_target(Func j) { } std::vector files; - files.push_back(fname + ".h"); - files.push_back(fname + o); + files.push_back(filename_prefix + ".h"); + files.push_back(filename_prefix + o); for (auto f : files) { Internal::ensure_no_file_exists(f); } - j.compile_to_multitarget_object_files(fname, j.infer_arguments(), targets, target_strings); + j.compile_to_multitarget_object_files(filename_prefix, j.infer_arguments(), targets, target_strings); for (auto f : files) { Internal::assert_file_exists(f); @@ -130,7 +137,7 @@ void test_compile_to_object_files_single_target(Func j) { } void test_compile_to_everything(Func j, bool do_object) { - std::string fname = get_fname(do_object ? "c5" : "c6"); + std::string filename_prefix = get_output_path_prefix(do_object ? "c5" : "c6"); const char *a = get_host_target().os == Target::Windows ? ".lib" : ".a"; const char *o = get_host_target().os == Target::Windows ? ".obj" : ".o"; @@ -156,18 +163,18 @@ void test_compile_to_everything(Func j, bool do_object) { ".schedule.h", "_schedule.py", a}) { if (do_object && !strcmp(ext, a)) continue; - files.push_back(fname + ext); + files.push_back(filename_prefix + ext); } if (do_object) { - files.push_back(fname + "_runtime" + o); - files.push_back(fname + "_wrapper" + o); + files.push_back(filename_prefix + "_runtime" + o); + files.push_back(filename_prefix + "_wrapper" + o); } // multi-file outputs for (const auto &s : target_strings) { for (const char *ext : {".s", ".bc", ".featurization", ".path_featurization", ".ll", ".stmt", ".stmt.html", o}) { if (!do_object && !strcmp(ext, o)) continue; - files.push_back(fname + "-" + s + ext); + files.push_back(filename_prefix + "-" + s + ext); } } @@ -182,26 +189,26 @@ void test_compile_to_everything(Func j, bool do_object) { return j.compile_to_module(args, name, target); }; std::map outputs = { - {OutputFileType::assembly, fname + ".s"}, // IsMulti - {OutputFileType::bitcode, fname + ".bc"}, // IsMulti - {OutputFileType::c_header, fname + ".h"}, // IsSingle - {OutputFileType::c_source, fname + ".halide_generated.cpp"}, // IsSingle - {OutputFileType::compiler_log, fname + ".halide_compiler_log"}, // IsSingle + {OutputFileType::assembly, filename_prefix + ".s"}, // IsMulti + {OutputFileType::bitcode, filename_prefix + ".bc"}, // IsMulti + {OutputFileType::c_header, filename_prefix + ".h"}, // IsSingle + {OutputFileType::c_source, filename_prefix + ".halide_generated.cpp"}, // IsSingle + {OutputFileType::compiler_log, filename_prefix + ".halide_compiler_log"}, // IsSingle // Note: compile_multitarget() doesn't produce cpp_stub output, // even if you pass this in. - // {OutputFileType::cpp_stub, fname + ".stub.h"}, // IsSingle - {OutputFileType::featurization, fname + ".featurization"}, // IsMulti - {OutputFileType::path_featurization, fname + ".path_featurization"}, // IsMulti - {OutputFileType::llvm_assembly, fname + ".ll"}, // IsMulti - {OutputFileType::object, fname + o}, // IsMulti - {OutputFileType::python_extension, fname + ".py.cpp"}, // IsSingle - {OutputFileType::pytorch_wrapper, fname + ".pytorch.h"}, // IsSingle - {OutputFileType::registration, fname + ".registration.cpp"}, // IsSingle - {OutputFileType::schedule, fname + ".schedule.h"}, // IsSingle - {OutputFileType::python_schedule, fname + "_schedule.py"}, // IsSingle - {OutputFileType::static_library, fname + a}, // IsSingle - {OutputFileType::stmt, fname + ".stmt"}, // IsMulti - {OutputFileType::stmt_html, fname + ".stmt.html"}, // IsMulti + // {OutputFileType::cpp_stub, filename_prefix + ".stub.h"}, // IsSingle + {OutputFileType::featurization, filename_prefix + ".featurization"}, // IsMulti + {OutputFileType::path_featurization, filename_prefix + ".path_featurization"}, // IsMulti + {OutputFileType::llvm_assembly, filename_prefix + ".ll"}, // IsMulti + {OutputFileType::object, filename_prefix + o}, // IsMulti + {OutputFileType::python_extension, filename_prefix + ".py.cpp"}, // IsSingle + {OutputFileType::pytorch_wrapper, filename_prefix + ".pytorch.h"}, // IsSingle + {OutputFileType::registration, filename_prefix + ".registration.cpp"}, // IsSingle + {OutputFileType::schedule, filename_prefix + ".schedule.h"}, // IsSingle + {OutputFileType::python_schedule, filename_prefix + "_schedule.py"}, // IsSingle + {OutputFileType::static_library, filename_prefix + a}, // IsSingle + {OutputFileType::stmt, filename_prefix + ".stmt"}, // IsMulti + {OutputFileType::stmt_html, filename_prefix + ".stmt.html"}, // IsMulti }; if (do_object) { outputs.erase(OutputFileType::static_library); @@ -214,7 +221,9 @@ void test_compile_to_everything(Func j, bool do_object) { // it exists or not -- so just fill in with arbitrary strings. return std::unique_ptr(new Internal::JSONCompilerLogger("generator_name", "function_name", "autoscheduler_name", Target(), "generator_args", false)); }; - compile_multitarget(fname, outputs, targets, target_strings, module_producer, compiler_logger_factory); + // The first argument to compile_multitarget is *function* name, not filename + std::string function_name = leaf_name(filename_prefix); + compile_multitarget(function_name, outputs, targets, target_strings, module_producer, compiler_logger_factory); for (auto f : files) { Internal::assert_file_exists(f); diff --git a/test/correctness/compute_at_split_rvar.cpp b/test/correctness/compute_at_split_rvar.cpp index bd12caf95cc6..d498c405cdf4 100644 --- a/test/correctness/compute_at_split_rvar.cpp +++ b/test/correctness/compute_at_split_rvar.cpp @@ -3,14 +3,8 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - int call_counter = 0; -extern "C" DLLEXPORT int count(int x) { +extern "C" HALIDE_EXPORT_SYMBOL int count(int x) { return call_counter++; } HalideExtern_1(int, count, int); diff --git a/test/correctness/compute_with.cpp b/test/correctness/compute_with.cpp index 8de8f660928d..fa5bb3c81ace 100644 --- a/test/correctness/compute_with.cpp +++ b/test/correctness/compute_with.cpp @@ -1,5 +1,6 @@ #include "Halide.h" #include "check_call_graphs.h" +#include "test_sharding.h" #include #include @@ -801,7 +802,7 @@ int multiple_outputs_on_gpu_test() { g.compute_with(f, x, LoopAlignStrategy::AlignEnd); - Realization r(f_im, g_im); + Realization r({f_im, g_im}); Pipeline({f, g}).realize(r); r[0].copy_to_host(); r[1].copy_to_host(); @@ -2206,152 +2207,58 @@ int two_compute_at_test() { } // namespace int main(int argc, char **argv) { - printf("Running split reorder test\n"); - if (split_test() != 0) { - return -1; - } - - printf("Running fuse test\n"); - if (fuse_test() != 0) { - return -1; - } - - printf("Running multiple fuse group test\n"); - if (multiple_fuse_group_test() != 0) { - return -1; - } - - printf("Running multiple outputs test\n"); - if (multiple_outputs_test() != 0) { - return -1; - } - - printf("Running double split fuse test\n"); - if (double_split_fuse_test() != 0) { - return -1; - } - - printf("Running vectorize test\n"); - if (vectorize_test() != 0) { - return -1; - } - - /* - * Note: we are deprecating skipping parts of a fused group in favor of - * cloning funcs in particular stages via a new (clone_)in overload. - * TODO: remove this code when the new clone_in is implemented. - */ - // printf("Running some are skipped test\n"); - // if (some_are_skipped_test() != 0) { - // return -1; - // } - - printf("Running rgb to yuv420 test\n"); - if (rgb_yuv420_test() != 0) { - return -1; - } - - printf("Running with specialization test\n"); - if (with_specialization_test() != 0) { - return -1; - } - - printf("Running fuse compute at test\n"); - if (fuse_compute_at_test() != 0) { - return -1; - } - - printf("Running nested compute with test\n"); - if (nested_compute_with_test() != 0) { - return -1; - } - - printf("Running mixed tile factor test\n"); - if (mixed_tile_factor_test() != 0) { - return -1; - } - - // NOTE: disabled because it generates OOB (see #4751 for discussion). - /* - printf("Running only some are tiled test\n"); - if (only_some_are_tiled_test() != 0) { - return -1; - } - */ - printf("Running multiple outputs on gpu test\n"); - if (multiple_outputs_on_gpu_test() != 0) { - return -1; - } - - printf("Running multi tile mixed tile factor test\n"); - if (multi_tile_mixed_tile_factor_test() != 0) { - return -1; - } - - printf("Running update stage test\n"); - if (update_stage_test() != 0) { - return -1; - } - - printf("Running update stage2 test\n"); - if (update_stage2_test() != 0) { - return -1; - } - - printf("Running update stage3 test\n"); - if (update_stage3_test() != 0) { - return -1; - } - - printf("Running update stage pairwise test\n"); - if (update_stage_pairwise_test() != 0) { - return -1; - } - - // I think this should work, but there is an overzealous check somewhere. - // printf("Running update stage pairwise zigzag test\n"); - // if (update_stage_pairwise_zigzag_test() != 0) { - // return -1; - // } - - printf("Running update stage diagonal test\n"); - if (update_stage_diagonal_test() != 0) { - return -1; - } - - printf("Running update stage rfactor test\n"); - if (update_stage_rfactor_test() != 0) { - return -1; - } - - printf("Running vectorize inlined test\n"); - if (vectorize_inlined_test() != 0) { - return -1; - } - - printf("Running mismatching splits test\n"); - if (mismatching_splits_test() != 0) { - return -1; - } - - printf("Running different arg number compute_at test\n"); - if (different_arg_num_compute_at_test() != 0) { - return -1; - } - - printf("Running store_at different levels test\n"); - if (store_at_different_levels_test() != 0) { - return -1; - } + struct Task { + std::string desc; + std::function fn; + }; - printf("Running rvar bounds test\n"); - if (rvar_bounds_test() != 0) { - return -1; - } + std::vector tasks = { + {"split reorder test", split_test}, + {"fuse test", fuse_test}, + {"multiple fuse group test", multiple_fuse_group_test}, + {"multiple outputs test", multiple_outputs_test}, + {"double split fuse test", double_split_fuse_test}, + {"vectorize test", vectorize_test}, + // + // Note: we are deprecating skipping parts of a fused group in favor of + // cloning funcs in particular stages via a new (clone_)in overload. + // TODO: remove this code when the new clone_in is implemented. + // + // {"some are skipped test", some_are_skipped_test}, + {"rgb to yuv420 test", rgb_yuv420_test}, + {"with specialization test", with_specialization_test}, + {"fuse compute at test", fuse_compute_at_test}, + {"nested compute with test", nested_compute_with_test}, + {"mixed tile factor test", mixed_tile_factor_test}, + // NOTE: disabled because it generates OOB (see #4751 for discussion). + // {"only some are tiled test", only_some_are_tiled_test}, + {"multiple outputs on gpu test", multiple_outputs_on_gpu_test}, + {"multi tile mixed tile factor test", multi_tile_mixed_tile_factor_test}, + {"update stage test", update_stage_test}, + {"update stage2 test", update_stage2_test}, + {"update stage3 test", update_stage3_test}, + {"update stage pairwise test", update_stage_pairwise_test}, + // I think this should work, but there is an overzealous check somewhere. + // {"update stage pairwise zigzag test", update_stage_pairwise_zigzag_test}, + {"update stage diagonal test", update_stage_diagonal_test}, + {"update stage rfactor test", update_stage_rfactor_test}, + {"vectorize inlined test", vectorize_inlined_test}, + {"mismatching splits test", mismatching_splits_test}, + {"different arg number compute_at test", different_arg_num_compute_at_test}, + {"store_at different levels test", store_at_different_levels_test}, + {"rvar bounds test", rvar_bounds_test}, + {"two_compute_at test", two_compute_at_test}, + }; - printf("Running two_compute_at test\n"); - if (two_compute_at_test() != 0) { - return -1; + using Sharder = Halide::Internal::Test::Sharder; + Sharder sharder; + for (size_t t = 0; t < tasks.size(); t++) { + if (!sharder.should_run(t)) continue; + const auto &task = tasks.at(t); + std::cout << task.desc << "\n"; + if (task.fn() != 0) { + return -1; + } } printf("Success!\n"); diff --git a/test/correctness/concat.cpp b/test/correctness/concat.cpp index fd5490e782cd..206c65449055 100644 --- a/test/correctness/concat.cpp +++ b/test/correctness/concat.cpp @@ -2,14 +2,8 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - int count[2]; -extern "C" DLLEXPORT int call_counter(int slot, int val) { +extern "C" HALIDE_EXPORT_SYMBOL int call_counter(int slot, int val) { count[slot]++; return val; } diff --git a/test/correctness/constraints.cpp b/test/correctness/constraints.cpp index a08ba3a6b34b..6899c007d606 100644 --- a/test/correctness/constraints.cpp +++ b/test/correctness/constraints.cpp @@ -8,7 +8,7 @@ using namespace Halide; bool error_occurred = false; void my_error_handler(JITUserContext *user_context, const char *msg) { - //printf("%s\n", msg); + // printf("%s\n", msg); error_occurred = true; } diff --git a/test/correctness/convolution.cpp b/test/correctness/convolution.cpp index 5721f3191a83..5e70a58abe9d 100644 --- a/test/correctness/convolution.cpp +++ b/test/correctness/convolution.cpp @@ -5,7 +5,7 @@ using namespace Halide; int main(int argc, char **argv) { - //int W = 64*3, H = 64*3; + // int W = 64*3, H = 64*3; const int W = 128, H = 48; Buffer in(W, H); diff --git a/test/correctness/convolution_multiple_kernels.cpp b/test/correctness/convolution_multiple_kernels.cpp index c494cc99bdd8..0b761d314a71 100644 --- a/test/correctness/convolution_multiple_kernels.cpp +++ b/test/correctness/convolution_multiple_kernels.cpp @@ -5,7 +5,7 @@ using namespace Halide; int main(int argc, char **argv) { - //int W = 64*3, H = 64*3; + // int W = 64*3, H = 64*3; const int W = 64, H = 16; Buffer in(W, H); diff --git a/test/correctness/custom_auto_scheduler.cpp b/test/correctness/custom_auto_scheduler.cpp index 32eec8b25dae..cda182861340 100644 --- a/test/correctness/custom_auto_scheduler.cpp +++ b/test/correctness/custom_auto_scheduler.cpp @@ -6,7 +6,11 @@ int call_count = 0; void inline_everything(const Pipeline &, const Target &, +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API const MachineParams &, +#else + const AutoschedulerParams &, +#endif AutoSchedulerResults *) { call_count++; // Inlining everything is really easy. @@ -22,13 +26,22 @@ int main(int argc, char **argv) { Func f; Var x; f(x) = 3; - Pipeline(f).auto_schedule(kSchedulerName, Target("host")); - - Pipeline::set_default_autoscheduler_name(kSchedulerName); Func g; g(x) = 3; - Pipeline(g).auto_schedule(Target("host")); + + Target t("host"); + +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + Pipeline(f).auto_schedule(kSchedulerName, t); + + Pipeline::set_default_autoscheduler_name(kSchedulerName); + Pipeline(g).auto_schedule(t); +#else + AutoschedulerParams autoscheduler_params(kSchedulerName); + Pipeline(f).apply_autoscheduler(t, autoscheduler_params); + Pipeline(g).apply_autoscheduler(t, autoscheduler_params); +#endif if (call_count != 2) { printf("Should have called the custom autoscheduler twice. Instead called it %d times\n", call_count); diff --git a/test/correctness/custom_lowering_pass.cpp b/test/correctness/custom_lowering_pass.cpp index 09de3ec23dfb..b446df1e8ed4 100644 --- a/test/correctness/custom_lowering_pass.cpp +++ b/test/correctness/custom_lowering_pass.cpp @@ -24,14 +24,8 @@ class CheckForFloatDivision : public IRMutator { // A mutator that injects code that counts floating point multiplies, // and an extern function that it calls out to for the accounting. -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - int multiply_count = 0; -extern "C" DLLEXPORT float record_float_mul(float arg) { +extern "C" HALIDE_EXPORT_SYMBOL float record_float_mul(float arg) { multiply_count++; return arg; } diff --git a/test/correctness/debug_to_file_multiple_outputs.cpp b/test/correctness/debug_to_file_multiple_outputs.cpp index ab012b3e4cbb..f3b3ddcf7a71 100644 --- a/test/correctness/debug_to_file_multiple_outputs.cpp +++ b/test/correctness/debug_to_file_multiple_outputs.cpp @@ -37,7 +37,7 @@ int main(int argc, char **argv) { Buffer f_im(size_x + 1, size_y); Buffer g_im(size_x, size_y), h_im(size_x, size_y); - Realization r(f_im, g_im, h_im); + Realization r({f_im, g_im, h_im}); p.realize(r); } diff --git a/test/correctness/early_out.cpp b/test/correctness/early_out.cpp new file mode 100644 index 000000000000..1a716c87f1f2 --- /dev/null +++ b/test/correctness/early_out.cpp @@ -0,0 +1,49 @@ +#include "Halide.h" + +using namespace Halide; + +int main(int argc, char **argv) { + // This is a test case that performs an or reduction using a where clause to + // get early-out behavior on the reduction loop. It triggered two bugs. + // + // First, there's a param that's only used in a specialization of a wrapper + // func, and this wasn't picked up by InferArguments. + // + // Second, there's a variable-free condition + // that feeds into bounds inference (test()), and bounds inference assumed + // that being variable-free meant it only depended on params and could be + // lifted out into a bounds expression. + // + // Both of these bugs caused compilation failures, so this test just + // verifies that things compile. + + Param height; + + Var y; + + Func test_rows("test_rows"); + test_rows(y) = y < 100; + + Func test("test"); + test() = cast(false); + RDom ry(0, 1024); + ry.where(!test()); + test() = test_rows(ry); + + Func output; + output() = select(test(), cast(0), cast(1)); + + Expr num_slices = (height + 255) / 256; + Expr slice_size = (height + num_slices - 1) / num_slices; + + test_rows.in() + .compute_root() + .specialize(height > slice_size) + .parallel(y, slice_size, TailStrategy::ShiftInwards); + + output.compile_jit(); + + printf("Success!\n"); + + return 0; +} diff --git a/test/correctness/exception.cpp b/test/correctness/exception.cpp index e692e933de1f..5da3677ade1f 100644 --- a/test/correctness/exception.cpp +++ b/test/correctness/exception.cpp @@ -39,7 +39,7 @@ int main(int argc, char **argv) { error = true; std::cout << "Expected compile error:\n" << e.what() << "\n"; - }; + } // We should have entered the catch block check_error(error); diff --git a/test/correctness/extern_bounds_inference.cpp b/test/correctness/extern_bounds_inference.cpp index dab90168256f..fdf30f5dea12 100644 --- a/test/correctness/extern_bounds_inference.cpp +++ b/test/correctness/extern_bounds_inference.cpp @@ -1,14 +1,8 @@ #include "Halide.h" #include -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - // An extern stage that translates. -extern "C" DLLEXPORT int translate(halide_buffer_t *in, int dx, int dy, halide_buffer_t *out) { +extern "C" HALIDE_EXPORT_SYMBOL int translate(halide_buffer_t *in, int dx, int dy, halide_buffer_t *out) { if (in->is_bounds_query()) { in->dim[0].min = out->dim[0].min + dx; diff --git a/test/correctness/extern_consumer.cpp b/test/correctness/extern_consumer.cpp index 76543392a019..f9db356cfb5a 100644 --- a/test/correctness/extern_consumer.cpp +++ b/test/correctness/extern_consumer.cpp @@ -5,15 +5,9 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - -extern "C" DLLEXPORT int dump_to_file(halide_buffer_t *input, const char *filename, - int desired_min, int desired_extent, - halide_buffer_t *) { +extern "C" HALIDE_EXPORT_SYMBOL int dump_to_file(halide_buffer_t *input, const char *filename, + int desired_min, int desired_extent, + halide_buffer_t *) { // Note the final output buffer argument is unused. if (input->is_bounds_query()) { // Request some range of the input buffer diff --git a/test/correctness/extern_consumer_tiled.cpp b/test/correctness/extern_consumer_tiled.cpp index 27bf9c7748aa..39d25157d129 100644 --- a/test/correctness/extern_consumer_tiled.cpp +++ b/test/correctness/extern_consumer_tiled.cpp @@ -5,13 +5,7 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - -extern "C" DLLEXPORT int copy_plus_xcoord(halide_buffer_t *input, int tile_extent_x, int tile_extent_y, halide_buffer_t *output) { +extern "C" HALIDE_EXPORT_SYMBOL int copy_plus_xcoord(halide_buffer_t *input, int tile_extent_x, int tile_extent_y, halide_buffer_t *output) { // Note the final output buffer argument is unused. if (input->is_bounds_query()) { for (int d = 0; d < 2; d++) { diff --git a/test/correctness/extern_error.cpp b/test/correctness/extern_error.cpp index f2a943f48ae9..d5df16e87649 100644 --- a/test/correctness/extern_error.cpp +++ b/test/correctness/extern_error.cpp @@ -3,20 +3,14 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - bool extern_error_called = false; -extern "C" DLLEXPORT int extern_error(JITUserContext *user_context, halide_buffer_t *out) { +extern "C" HALIDE_EXPORT_SYMBOL int extern_error(JITUserContext *user_context, halide_buffer_t *out) { extern_error_called = true; return -1; } bool error_occurred = false; -extern "C" DLLEXPORT void my_halide_error(JITUserContext *user_context, const char *msg) { +extern "C" HALIDE_EXPORT_SYMBOL void my_halide_error(JITUserContext *user_context, const char *msg) { printf("Expected: %s\n", msg); error_occurred = true; } diff --git a/test/correctness/extern_output_expansion.cpp b/test/correctness/extern_output_expansion.cpp index 6d8018950369..1d243d1949c9 100644 --- a/test/correctness/extern_output_expansion.cpp +++ b/test/correctness/extern_output_expansion.cpp @@ -1,14 +1,8 @@ #include "Halide.h" #include -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - // out(x) = in(x) * x; -extern "C" DLLEXPORT int extern_stage(halide_buffer_t *in, halide_buffer_t *out) { +extern "C" HALIDE_EXPORT_SYMBOL int extern_stage(halide_buffer_t *in, halide_buffer_t *out) { assert(in->type == halide_type_of()); assert(out->type == halide_type_of()); if (in->host == nullptr || out->host == nullptr) { diff --git a/test/correctness/extern_partial.cpp b/test/correctness/extern_partial.cpp index 909ca9475611..670aac4ace5c 100644 --- a/test/correctness/extern_partial.cpp +++ b/test/correctness/extern_partial.cpp @@ -5,13 +5,7 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - -extern "C" DLLEXPORT int copy_row_plus_xcoord(halide_buffer_t *input, halide_buffer_t *output) { +extern "C" HALIDE_EXPORT_SYMBOL int copy_row_plus_xcoord(halide_buffer_t *input, halide_buffer_t *output) { // Note the final output buffer argument is unused. if (input->is_bounds_query()) { for (int d = 0; d < 2; d++) { diff --git a/test/correctness/extern_producer.cpp b/test/correctness/extern_producer.cpp index 51b7490b25d3..4005c1beb653 100644 --- a/test/correctness/extern_producer.cpp +++ b/test/correctness/extern_producer.cpp @@ -3,12 +3,6 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - // Some helper functions for rounding int round_down(int, int); int round_up(int x, int m) { @@ -29,7 +23,7 @@ int round_down(int x, int m) { // Imagine that this loads from a file, or tiled storage. Here we'll just fill in the data using a // periodic integer function. -extern "C" DLLEXPORT int make_data(halide_buffer_t *out) { +extern "C" HALIDE_EXPORT_SYMBOL int make_data(halide_buffer_t *out) { static int desired_row_extent = 0; if (out->is_bounds_query()) { // Bounds query mode. To make life interesting, let's add some @@ -72,7 +66,7 @@ extern "C" DLLEXPORT int make_data(halide_buffer_t *out) { // Imagine that this loads from a file, or tiled storage. Here we'll just fill in the data using a // periodic integer function. -extern "C" DLLEXPORT int make_data_multi(halide_buffer_t *out1, halide_buffer_t *out2) { +extern "C" HALIDE_EXPORT_SYMBOL int make_data_multi(halide_buffer_t *out1, halide_buffer_t *out2) { if (!out1->host || !out2->host) { // Bounds query mode. We're ok with any requested output size (Halide guarantees they match). return 0; diff --git a/test/correctness/extern_reorder_storage.cpp b/test/correctness/extern_reorder_storage.cpp index 2a7374101f77..036f3adc5d0e 100644 --- a/test/correctness/extern_reorder_storage.cpp +++ b/test/correctness/extern_reorder_storage.cpp @@ -1,14 +1,8 @@ #include "Halide.h" #include -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - // An extern stage that translates. -extern "C" DLLEXPORT int copy_and_check_strides(halide_buffer_t *in, halide_buffer_t *out) { +extern "C" HALIDE_EXPORT_SYMBOL int copy_and_check_strides(halide_buffer_t *in, halide_buffer_t *out) { if (in->is_bounds_query()) { for (int i = 0; i < 2; i++) { in->dim[i].min = out->dim[i].min; diff --git a/test/correctness/extern_sort.cpp b/test/correctness/extern_sort.cpp index 4ead51c43486..f44568c52fc1 100644 --- a/test/correctness/extern_sort.cpp +++ b/test/correctness/extern_sort.cpp @@ -4,14 +4,8 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - // Use an extern stage to do a sort -extern "C" DLLEXPORT int sort_buffer(halide_buffer_t *in, halide_buffer_t *out) { +extern "C" HALIDE_EXPORT_SYMBOL int sort_buffer(halide_buffer_t *in, halide_buffer_t *out) { if (in->is_bounds_query()) { in->dim[0].min = out->dim[0].min; in->dim[0].extent = out->dim[0].extent; diff --git a/test/correctness/extern_stage.cpp b/test/correctness/extern_stage.cpp index 8000588f5a69..a30c25804bfb 100644 --- a/test/correctness/extern_stage.cpp +++ b/test/correctness/extern_stage.cpp @@ -1,13 +1,7 @@ #include "Halide.h" #include -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - -extern "C" DLLEXPORT int flip_x(halide_buffer_t *in1, halide_buffer_t *in2, halide_buffer_t *out) { +extern "C" HALIDE_EXPORT_SYMBOL int flip_x(halide_buffer_t *in1, halide_buffer_t *in2, halide_buffer_t *out) { int min = out->dim[0].min; int max = out->dim[0].min + out->dim[0].extent - 1; @@ -31,7 +25,7 @@ extern "C" DLLEXPORT int flip_x(halide_buffer_t *in1, halide_buffer_t *in2, hali // We don't mutate the output buffer, because we can handle // any size output. - //printf("Bounds inference flip_x over [%d %d] requires [%d %d]\n", min, extent, flipped_min, extent); + // printf("Bounds inference flip_x over [%d %d] requires [%d %d]\n", min, extent, flipped_min, extent); } else { assert(in1->type == halide_type_of()); assert(in2->type == halide_type_of()); diff --git a/test/correctness/extern_stage_on_device.cpp b/test/correctness/extern_stage_on_device.cpp index 4ee2c3e77728..1397e22193a3 100644 --- a/test/correctness/extern_stage_on_device.cpp +++ b/test/correctness/extern_stage_on_device.cpp @@ -4,19 +4,13 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - // An extern stage implemented by a Halide pipeline running // either on host or device. The outer Halide filter must // override the "device_api" parameter of Func::define_extern // when using the extern_stage on device. -extern "C" DLLEXPORT int extern_stage(int extern_on_device, - int outer_filter_on_device, - halide_buffer_t *out) { +extern "C" HALIDE_EXPORT_SYMBOL int extern_stage(int extern_on_device, + int outer_filter_on_device, + halide_buffer_t *out) { if (!out->is_bounds_query()) { if (extern_on_device > 0 && outer_filter_on_device > 0) { // If both the extern and the outer filter are on running on diff --git a/test/correctness/external_code.cpp b/test/correctness/external_code.cpp deleted file mode 100644 index 87f8bcfa1437..000000000000 --- a/test/correctness/external_code.cpp +++ /dev/null @@ -1,66 +0,0 @@ -#include "Halide.h" -#include "halide_test_dirs.h" - -#include -#include - -#include -#include - -using namespace Halide; - -int main(int argc, char **argv) { - if (get_jit_target_from_environment().arch == Target::WebAssembly) { - printf("[SKIP] Skipping test for WebAssembly as it does not support ExternalCode::bitcode_wrapper().\n"); - return 0; - } - - Var x("x"), y("y"); - Func f("f"); - - f(x, y) = 42; - - Target target = get_jit_target_from_environment(); - - std::string bitcode_file = Internal::get_test_tmp_dir() + "extern.bc"; - f.compile_to_bitcode(bitcode_file, {}, "extern", target); - - std::vector bitcode; - std::ifstream bitcode_stream(bitcode_file, std::ios::in | std::ios::binary); - bitcode_stream.seekg(0, std::ios::end); - bitcode.resize(bitcode_stream.tellg()); - bitcode_stream.seekg(0, std::ios::beg); - bitcode_stream.read(reinterpret_cast(&bitcode[0]), bitcode.size()); - - ExternalCode external_code = - ExternalCode::bitcode_wrapper(target, bitcode, "extern"); - - Func f_extern; - f_extern.define_extern("extern", {}, type_of(), 2); - - Func result; - result(x, y) = f_extern(x, y); - - Module module = result.compile_to_module({}, "forty_two", target); - - module.append(external_code); - - auto forty_two = module.get_function_by_name("forty_two"); - - Internal::JITModule jit_module(module, forty_two, {}); - - auto main_function = (int (*)(halide_buffer_t * buf)) jit_module.main_function(); - Buffer buf(16, 16); - - int ret_code = main_function(buf.raw_buffer()); - - assert(ret_code == 0); - for (int i = 0; i < 16; i++) { - for (int j = 0; j < 16; j++) { - assert(buf(i, j) == 42); - } - } - - printf("Success!\n"); - return 0; -} diff --git a/test/correctness/extract_concat_bits.cpp b/test/correctness/extract_concat_bits.cpp new file mode 100644 index 000000000000..20e11a75e1ce --- /dev/null +++ b/test/correctness/extract_concat_bits.cpp @@ -0,0 +1,152 @@ +#include "Halide.h" + +using namespace Halide; + +class CountOps : public Internal::IRMutator { + Expr visit(const Internal::Reinterpret *op) override { + std::cerr << Expr(op) << " " << op->type.lanes() << " " << op->value.type().lanes() << "\n"; + if (op->type.lanes() != op->value.type().lanes()) { + std::cerr << "Got one\n"; + reinterprets++; + } + return Internal::IRMutator::visit(op); + } + + Expr visit(const Internal::Call *op) override { + if (op->is_intrinsic(Internal::Call::concat_bits)) { + concats++; + } else if (op->is_intrinsic(Internal::Call::extract_bits)) { + extracts++; + } + return Internal::IRMutator::visit(op); + } + +public: + int extracts = 0, concats = 0, reinterprets = 0; +}; + +int main(int argc, char **argv) { + for (bool vectorize : {false, true}) { + // Reinterpret an array of a wide type as a larger array of a smaller type + Func f, g; + Var x; + + f(x) = cast(x); + + // Reinterpret to a narrower type. + g(x) = extract_bits(f(x / 4), 8 * (x % 4)); + + f.compute_root(); + + if (vectorize) { + f.vectorize(x, 8); + // The align_bounds directive is critical so that the x%4 term above collapses. + g.align_bounds(x, 4).vectorize(x, 32); + + // An alternative to the align_bounds call: + // g.output_buffer().dim(0).set_min(0); + } + + CountOps counter; + g.add_custom_lowering_pass(&counter, nullptr); + + Buffer out = g.realize({1024}); + std::cerr << counter.extracts << " " << counter.reinterprets << " " << counter.concats << "\n"; + + if (vectorize) { + if (counter.extracts > 0) { + printf("Saw an unwanted extract_bits call in lowered code\n"); + return -1; + } else if (counter.reinterprets == 0) { + printf("Did not see a vector reinterpret in lowered code\n"); + return -1; + } + } + + for (uint32_t i = 0; i < (uint32_t)out.width(); i++) { + uint8_t correct = (i / 4) >> (8 * (i % 4)); + if (out(i) != correct) { + printf("out(%d) = %d instead of %d\n", i, out(i), correct); + return -1; + } + } + } + + for (bool vectorize : {false, true}) { + // Reinterpret an array of a narrow type as a smaller array of a wide type + Func f, g; + Var x; + + f(x) = cast(x); + + g(x) = concat_bits({f(4 * x), f(4 * x + 1), f(4 * x + 2), f(4 * x + 3)}); + + f.compute_root(); + + if (vectorize) { + f.vectorize(x, 32); + g.vectorize(x, 8); + } + + CountOps counter; + g.add_custom_lowering_pass(&counter, nullptr); + + Buffer out = g.realize({64}); + + if (counter.concats > 0) { + printf("Saw an unwanted concat_bits call in lowered code\n"); + return -1; + } else if (counter.reinterprets == 0) { + printf("Did not see a vector reinterpret in lowered code\n"); + return -1; + } + + for (int i = 0; i < 64; i++) { + for (int b = 0; b < 4; b++) { + uint8_t correct = i * 4 + b; + uint8_t result = (out(i) >> (b * 8)) & 0xff; + if (result != correct) { + printf("out(%d) byte %d = %d instead of %d\n", i, b, result, correct); + return -1; + } + } + } + } + + // Also test cases that aren't expected to fold into reinterprets + { + Func f; + Var x("x"); + f(x) = cast(x); + + auto check = [&](const Expr &a, const Expr &b) { + Func g; + g(x) = cast(a == b); + Buffer out = g.realize({1024}); + for (int i = 0; i < out.width(); i++) { + if (out(i) == 0) { + std::cerr << "Mismatch between: " << a << " and " << b << " when x == " << i << "\n"; + exit(-1); + } + } + }; + + // concat_bits is little-endian + check(concat_bits({f(x), cast(37)}), cast(f(x)) + (37 << 16)); + check(concat_bits({cast(0), f(x), cast(0), cast(0)}), cast(UInt(64), f(x)) << 16); + + // extract_bits is equivalent to right shifting and then casting to a narrower type + check(extract_bits(f(x), 3), cast(f(x) >> 3)); + + // Extract bits zero-fills out-of-range bits + check(extract_bits(f(x), 3), f(x) >> 3); + check(extract_bits(f(x), 8), (f(x) >> 8) & 0xff); + check(extract_bits(f(x), -1), cast(f(x)) << 1); + + // MSB of the mantissa of an ieee float + check(extract_bits(cast(f(x)), 15), cast(reinterpret(cast(f(x))) >> 15)); + } + + printf("Success!\n"); + return 0; +} diff --git a/test/correctness/float16_t_neon_op_check.cpp b/test/correctness/float16_t_neon_op_check.cpp index 40593e60e19d..bca7d6be8765 100644 --- a/test/correctness/float16_t_neon_op_check.cpp +++ b/test/correctness/float16_t_neon_op_check.cpp @@ -324,9 +324,11 @@ int main(int argc, char **argv) { // Only for 64bit target with fp16 feature if (!(hl_target.arch == Target::ARM && hl_target.bits == 64 && hl_target.has_feature(Target::ARMFp16))) { + Halide::Internal::Test::Sharder::accept_sharded_status(); printf("[SKIP] To run this test, set HL_TARGET=arm-64--arm_fp16. \n"); return 0; } + // Create Test Object // Use smaller dimension than default(768, 128) to avoid fp16 overflow in reduction test case SimdOpCheck test(hl_target, 384, 32); @@ -337,27 +339,12 @@ int main(int argc, char **argv) { if (argc > 1) { test.filter = argv[1]; - test.set_num_threads(1); } if (getenv("HL_SIMD_OP_CHECK_FILTER")) { test.filter = getenv("HL_SIMD_OP_CHECK_FILTER"); } - // TODO: multithreading here is the cause of https://github.com/halide/Halide/issues/3669; - // the fundamental issue is that we make one set of ImageParams to construct many - // Exprs, then realize those Exprs on arbitrary threads; it is known that sharing - // one Func across multiple threads is not guaranteed to be safe, and indeed, TSAN - // reports data races, of which some are likely 'benign' (e.g. Function.freeze) but others - // are highly suspect (e.g. Function.lock_loop_levels). Since multithreading here - // was added just to avoid having this test be the last to finish, the expedient 'fix' - // for now is to remove the multithreading. A proper fix could be made by restructuring this - // test so that every Expr constructed for testing was guaranteed to share no Funcs - // (Function.deep_copy() perhaps). Of course, it would also be desirable to allow Funcs, Exprs, etc - // to be usable across multiple threads, but that is a major undertaking that is - // definitely not worthwhile for present Halide usage patterns. - test.set_num_threads(1); - if (argc > 2) { // Don't forget: if you want to run the standard tests to a specific output // directory, you'll need to invoke with the first arg enclosed diff --git a/test/correctness/gpu_allocation_cache.cpp b/test/correctness/gpu_allocation_cache.cpp index a0f104b83f7a..97215b95aea5 100644 --- a/test/correctness/gpu_allocation_cache.cpp +++ b/test/correctness/gpu_allocation_cache.cpp @@ -125,10 +125,10 @@ int main(int argc, char **argv) { }; // First run them serially (compilation of a Func isn't thread-safe). - //test1(true); - //test2(true); - //test3(true); - //return 0; + // test1(true); + // test2(true); + // test3(true); + // return 0; // Now run all at the same time to check for concurrency issues. diff --git a/test/correctness/gpu_mixed_dimensionality.cpp b/test/correctness/gpu_mixed_dimensionality.cpp index d2644e14bc63..aabd118271ec 100644 --- a/test/correctness/gpu_mixed_dimensionality.cpp +++ b/test/correctness/gpu_mixed_dimensionality.cpp @@ -28,8 +28,8 @@ int main(int argc, char **argv) { h.compute_at(out, x).gpu_threads(x, y); h.update().gpu_threads(x); // TODO: NormalizeDimensionality in FuseGPUThreadLoops.cpp doesn't work in the following case. - //g.compute_at(h, y).gpu_threads(x); - //g.update(); + // g.compute_at(h, y).gpu_threads(x); + // g.update(); g.compute_at(h, x); g.update(); f.compute_at(g, x); diff --git a/test/correctness/gpu_vectorize.cpp b/test/correctness/gpu_vectorize.cpp index 407342cca1af..2e0ffeebc3b5 100644 --- a/test/correctness/gpu_vectorize.cpp +++ b/test/correctness/gpu_vectorize.cpp @@ -70,6 +70,42 @@ int main(int argc, char **argv) { } } } + { + Var x("x"), y("y"), xi("xi"), yi("yi"); + Func f("f"); + ImageParam im(Float(32), 2); + + printf("Defining function...\n"); + + f(x, y) = select(im(x, y) > 32.0f, 1.0f, -1.0f) + im(x, y); + + Target target = get_jit_target_from_environment(); + if (target.has_gpu_feature()) { + f.gpu_tile(x, y, xi, yi, 8, 8, TailStrategy::GuardWithIf).vectorize(xi, 4, TailStrategy::GuardWithIf); + } + + printf("Realizing function...\n"); + Buffer input_img(32, 32); + for (int i = 0; i < 32; i++) { + for (int j = 0; j < 32; j++) { + input_img(i, j) = i + j; + } + } + im.set(input_img); + + Buffer imf = f.realize({32, 32}, target); + + // Check the result was what we expected + for (int i = 0; i < 32; i++) { + for (int j = 0; j < 32; j++) { + float correct = (i + j > 32 ? 1.0f : -1.0f) + i + j; + if (fabs(imf(i, j) - correct) > 0.001f) { + printf("imf[%d, %d] = %f instead of %f\n", i, j, imf(i, j), correct); + return -1; + } + } + } + } printf("Success!\n"); return 0; diff --git a/test/correctness/handle.cpp b/test/correctness/handle.cpp index 67d7a1964e8d..075c7f17e9c8 100644 --- a/test/correctness/handle.cpp +++ b/test/correctness/handle.cpp @@ -3,15 +3,9 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - // Make a custom strlen so that it always returns a 32-bit int, // instead of switching based on bit-width. -extern "C" DLLEXPORT int my_strlen(const char *c) { +extern "C" HALIDE_EXPORT_SYMBOL int my_strlen(const char *c) { int l = 0; while (*c) { c++; diff --git a/test/correctness/hexagon_scatter.cpp b/test/correctness/hexagon_scatter.cpp index 88c9f998b681..6fde0669ac19 100644 --- a/test/correctness/hexagon_scatter.cpp +++ b/test/correctness/hexagon_scatter.cpp @@ -97,6 +97,11 @@ int test() { } int main() { + if (!get_jit_target_from_environment().has_feature(Target::HVX)) { + printf("[SKIP] hexagon_scatter is only useful when targeting HVX.\n"); + return 0; + } + if (!test() || !test() || !test() || diff --git a/test/correctness/host_alignment.cpp b/test/correctness/host_alignment.cpp index 4dc7bf40a376..20c6644ff968 100644 --- a/test/correctness/host_alignment.cpp +++ b/test/correctness/host_alignment.cpp @@ -74,10 +74,9 @@ class CountHostAlignmentAsserts : public IRVisitor { left = call->args[0]; right = call->args[1]; } - const Call *reinterpret_call = left.as(); - if (!reinterpret_call || - !reinterpret_call->is_intrinsic(Call::reinterpret)) return; - Expr name = reinterpret_call->args[0]; + const Reinterpret *reinterpret = left.as(); + if (!reinterpret) return; + Expr name = reinterpret->value; const Variable *V = name.as(); string name_host_ptr = V->name; int expected_alignment = alignments_needed[name_host_ptr]; diff --git a/test/correctness/image_of_lists.cpp b/test/correctness/image_of_lists.cpp index 474770e55900..204db6cfe9ce 100644 --- a/test/correctness/image_of_lists.cpp +++ b/test/correctness/image_of_lists.cpp @@ -5,18 +5,12 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - -extern "C" DLLEXPORT std::list *list_create(int) { +extern "C" HALIDE_EXPORT_SYMBOL std::list *list_create(int) { return new std::list(); } HalideExtern_1(std::list *, list_create, int); -extern "C" DLLEXPORT std::list *list_maybe_insert(std::list *list, bool insert, int value) { +extern "C" HALIDE_EXPORT_SYMBOL std::list *list_maybe_insert(std::list *list, bool insert, int value) { if (insert) { list->push_back(value); } @@ -50,16 +44,16 @@ int main(int argc, char **argv) { // Inspect the results for correctness for (int i = 0; i < 100; i++) { std::list *list = result(i); - //printf("Factors of %d: ", i); + // printf("Factors of %d: ", i); for (std::list::iterator iter = list->begin(); iter != list->end(); iter++) { int factor = *iter; if (i % factor) { printf("Error: %d is not a factor of %d\n", factor, i); return -1; } - //printf("%d ", factor); + // printf("%d ", factor); } - //printf("\n"); + // printf("\n"); delete list; } diff --git a/test/correctness/intrinsics.cpp b/test/correctness/intrinsics.cpp index f6e24a81c877..068f360d214f 100644 --- a/test/correctness/intrinsics.cpp +++ b/test/correctness/intrinsics.cpp @@ -72,7 +72,6 @@ void check_intrinsics_over_range() { {halving_add(a_expr, b_expr), (a + b) >> 1}, {rounding_halving_add(a_expr, b_expr), (a + b + 1) >> 1}, {halving_sub(a_expr, b_expr), (a - b) >> 1}, - {rounding_halving_sub(a_expr, b_expr), (a - b + 1) >> 1}, }; for (const auto &p : intrinsics_with_reference_answer) { Expr test = lower_intrinsics(p.first); @@ -222,12 +221,6 @@ int main(int argc, char **argv) { check(u8((widening_add(u8x, u8y) + 1) / 2), rounding_halving_add(u8x, u8y)); check((i32x + i32y + 1) / 2, rounding_halving_add(i32x, i32y)); - check(i8((i16(i8x) - i8y + 1) / 2), rounding_halving_sub(i8x, i8y)); - check(u8((u16(u8x) - u8y + 1) / 2), rounding_halving_sub(u8x, u8y)); - check(i8((widening_sub(i8x, i8y) + 1) / 2), rounding_halving_sub(i8x, i8y)); - check(u8((widening_sub(u8x, u8y) + 1) / 2), rounding_halving_sub(u8x, u8y)); - check((i32x - i32y + 1) / 2, rounding_halving_sub(i32x, i32y)); - // Check absd check(abs(i16(i8x) - i16(i8y)), u16(absd(i8x, i8y))); check(abs(i16(u8x) - i16(u8y)), u16(absd(u8x, u8y))); @@ -253,7 +246,7 @@ int main(int argc, char **argv) { check(narrow((u16(u8x) + 500) >> 4), narrow((u16(u8x) + 500) >> 4)); check((u64(u32x) + 8) / 16, u64(rounding_shift_right(u32x, 4))); - check(u16(min((u64(u32x) + 8) / 16, 65535)), u16(min(rounding_shift_right(u32x, 4), 65535))); + check(u16(min((u64(u32x) + 8) / 16, 65535)), u16_sat(rounding_shift_right(u32x, 4))); // And with variable shifts. check(i8(widening_add(i8x, (i8(1) << u8y) / 2) >> u8y), rounding_shift_right(i8x, u8y)); diff --git a/test/correctness/inverse.cpp b/test/correctness/inverse.cpp index c86a4803397a..3e4a80a4afb1 100644 --- a/test/correctness/inverse.cpp +++ b/test/correctness/inverse.cpp @@ -24,7 +24,7 @@ void check(Buffer a, Buffer b) { int err = bits_diff(a(i), b(i)); if (err > 13) { printf("Mismatch in mantissa at %d: %10.10f %10.10f. Differs by %d bits.\n", i, a(i), b(i), err); - //exit(-1); + // exit(-1); } } } diff --git a/test/correctness/issue_3926.cpp b/test/correctness/issue_3926.cpp index 140af87e6e35..30d7cd0b1713 100644 --- a/test/correctness/issue_3926.cpp +++ b/test/correctness/issue_3926.cpp @@ -11,7 +11,7 @@ int main(int argc, char *argv[]) { f(x) = x; g(x, y) = f(x) + select(param, 1, 2); - //g.gpu_tile(x, y, tx, ty, 8, 8, TailStrategy::GuardWithIf); + // g.gpu_tile(x, y, tx, ty, 8, 8, TailStrategy::GuardWithIf); g.specialize(param).tile(x, y, tx, ty, 8, 8, TailStrategy::GuardWithIf); g.specialize(!param).tile(x, y, tx, ty, 8, 8, TailStrategy::GuardWithIf); g.specialize_fail("Unknown"); diff --git a/test/correctness/lazy_convolution.cpp b/test/correctness/lazy_convolution.cpp index fd8adb38d61f..d54202755648 100644 --- a/test/correctness/lazy_convolution.cpp +++ b/test/correctness/lazy_convolution.cpp @@ -3,14 +3,8 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - int call_count; -extern "C" DLLEXPORT float call_counter(float x) { +extern "C" HALIDE_EXPORT_SYMBOL float call_counter(float x) { call_count++; return x; } diff --git a/test/correctness/loop_invariant_extern_calls.cpp b/test/correctness/loop_invariant_extern_calls.cpp index 37c617e783a6..e266f477c103 100644 --- a/test/correctness/loop_invariant_extern_calls.cpp +++ b/test/correctness/loop_invariant_extern_calls.cpp @@ -5,20 +5,14 @@ using namespace Halide; // NB: You must compile with -rdynamic for llvm to be able to find the appropriate symbols -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - int call_counter[] = {0, 0, 0, 0, 0, 0}; -extern "C" DLLEXPORT int my_func(int counter, int x) { +extern "C" HALIDE_EXPORT_SYMBOL int my_func(int counter, int x) { call_counter[counter]++; return x; } HalidePureExtern_2(int, my_func, int, int); -extern "C" DLLEXPORT int my_impure_func(int counter, int x) { +extern "C" HALIDE_EXPORT_SYMBOL int my_impure_func(int counter, int x) { call_counter[counter]++; return x; } diff --git a/test/correctness/low_bit_depth_noise.cpp b/test/correctness/low_bit_depth_noise.cpp new file mode 100644 index 000000000000..d9d24d60d7b0 --- /dev/null +++ b/test/correctness/low_bit_depth_noise.cpp @@ -0,0 +1,46 @@ +#include "Halide.h" + +using namespace Halide; + +int main(int argc, char **argv) { + // Halide only provides 32-bit noise functions, which are overkill for + // generating low bit-depth noise (e.g. for dithering). This test shows how + // to generate 8-bit noise by slicing out bytes from 32-bit noise. + Var x; + + Func noise; + noise(x) = random_uint(); + + Func noise8; + noise8(x) = extract_bits(noise(x / 4), 8 * (x % 4)); + + Func in16; + in16(x) = cast(x); + + Func dithered; + dithered(x) = cast((in16(x) + noise8(x)) >> 8); + + in16.compute_root(); + dithered.compute_root().vectorize(x, 16, TailStrategy::RoundUp); + noise8.compute_at(dithered, x).vectorize(x); + + // To keep things aligned: + dithered.output_buffer().dim(0).set_min(0); + + Buffer out = dithered.realize({1 << 15}); + + uint32_t sum = 0, correct_sum = 0; + for (int i = 0; i < out.width(); i++) { + sum += out(i); + correct_sum += i; + } + correct_sum = (correct_sum + 128) >> 8; + + if (std::abs((double)sum - correct_sum) / correct_sum > 1e-4) { + printf("Suspiciously large relative difference between the sum of the dithered values and the full-precision sum: %d vs %d\n", sum, correct_sum); + return -1; + } + + printf("Success!\n"); + return 0; +} diff --git a/test/correctness/make_struct.cpp b/test/correctness/make_struct.cpp index 6147ac1790ff..4c3f06cd89ca 100644 --- a/test/correctness/make_struct.cpp +++ b/test/correctness/make_struct.cpp @@ -11,13 +11,7 @@ struct struct_t { const char *d; }; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - -extern "C" DLLEXPORT int check_struct(struct_t *s) { +extern "C" HALIDE_EXPORT_SYMBOL int check_struct(struct_t *s) { if (s->a != 3.0 || s->b != 1234567 || s->c != 1234 || diff --git a/test/correctness/many_small_extern_stages.cpp b/test/correctness/many_small_extern_stages.cpp index 15de509bc264..9b177a2aded1 100644 --- a/test/correctness/many_small_extern_stages.cpp +++ b/test/correctness/many_small_extern_stages.cpp @@ -1,19 +1,13 @@ #include "Halide.h" #include -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - void dump_buffer_shape(halide_buffer_t *b) { for (int i = 0; i < b->dimensions; i++) { printf(" %d %d %d\n", b->dim[i].min, b->dim[i].extent, b->dim[i].stride); } } -extern "C" DLLEXPORT int copy(halide_buffer_t *in, halide_buffer_t *out) { +extern "C" HALIDE_EXPORT_SYMBOL int copy(halide_buffer_t *in, halide_buffer_t *out) { /* printf("out:\n"); diff --git a/test/correctness/memoize.cpp b/test/correctness/memoize.cpp index 756a78a6c6a9..9acec33647aa 100644 --- a/test/correctness/memoize.cpp +++ b/test/correctness/memoize.cpp @@ -6,17 +6,11 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - // External functions to track whether the cache is working. int call_count = 0; -extern "C" DLLEXPORT int count_calls(halide_buffer_t *out) { +extern "C" HALIDE_EXPORT_SYMBOL int count_calls(halide_buffer_t *out) { if (!out->is_bounds_query()) { call_count++; Halide::Runtime::Buffer(*out).fill(42); @@ -26,7 +20,7 @@ extern "C" DLLEXPORT int count_calls(halide_buffer_t *out) { int call_count_with_arg = 0; -extern "C" DLLEXPORT int count_calls_with_arg(uint8_t val, halide_buffer_t *out) { +extern "C" HALIDE_EXPORT_SYMBOL int count_calls_with_arg(uint8_t val, halide_buffer_t *out) { if (!out->is_bounds_query()) { call_count_with_arg++; Halide::Runtime::Buffer(*out).fill(val); @@ -36,7 +30,7 @@ extern "C" DLLEXPORT int count_calls_with_arg(uint8_t val, halide_buffer_t *out) int call_count_with_arg_parallel[8]; -extern "C" DLLEXPORT int count_calls_with_arg_parallel(uint8_t val, halide_buffer_t *out) { +extern "C" HALIDE_EXPORT_SYMBOL int count_calls_with_arg_parallel(uint8_t val, halide_buffer_t *out) { if (!out->is_bounds_query()) { call_count_with_arg_parallel[out->dim[2].min]++; Halide::Runtime::Buffer(*out).fill(val); @@ -46,7 +40,7 @@ extern "C" DLLEXPORT int count_calls_with_arg_parallel(uint8_t val, halide_buffe int call_count_staged[4]; -extern "C" DLLEXPORT int count_calls_staged(int32_t stage, uint8_t val, halide_buffer_t *in, halide_buffer_t *out) { +extern "C" HALIDE_EXPORT_SYMBOL int count_calls_staged(int32_t stage, uint8_t val, halide_buffer_t *in, halide_buffer_t *out) { if (in->is_bounds_query()) { for (int i = 0; i < out->dimensions; i++) { in->dim[i] = out->dim[i]; @@ -60,7 +54,7 @@ extern "C" DLLEXPORT int count_calls_staged(int32_t stage, uint8_t val, halide_b return 0; } -extern "C" DLLEXPORT int computed_eviction_key(int a) { +extern "C" HALIDE_EXPORT_SYMBOL int computed_eviction_key(int a) { return 2020 + a; } HalideExtern_1(int, computed_eviction_key, int); diff --git a/test/correctness/memoize_cloned.cpp b/test/correctness/memoize_cloned.cpp index 6335c836eaf4..89e8161ff673 100644 --- a/test/correctness/memoize_cloned.cpp +++ b/test/correctness/memoize_cloned.cpp @@ -2,14 +2,8 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - int call_count; -extern "C" DLLEXPORT int call_counter(int x) { +extern "C" HALIDE_EXPORT_SYMBOL int call_counter(int x) { call_count++; return x; } diff --git a/test/correctness/mul_div_mod.cpp b/test/correctness/mul_div_mod.cpp index 29ce6557ea3e..af82c8bf26be 100644 --- a/test/correctness/mul_div_mod.cpp +++ b/test/correctness/mul_div_mod.cpp @@ -1,7 +1,7 @@ #include "Halide.h" +#include "test_sharding.h" #include -#include #include #include @@ -501,54 +501,45 @@ bool f_mod() { return success; } -bool test_mul(int vector_width, ScheduleVariant scheduling, Target target) { - std::cout << "Testing mul vector_width: " + std::to_string(vector_width) + "\n"; - - bool success = true; +struct Task { + std::function fn; +}; +void add_test_mul(int vector_width, ScheduleVariant scheduling, Target target, std::vector &tasks) { // Non-widening multiplication. - success &= mul(vector_width, scheduling, target); - success &= mul(vector_width, scheduling, target); - success &= mul(vector_width, scheduling, target); - success &= mul(vector_width, scheduling, target); - success &= mul(vector_width, scheduling, target); - success &= mul(vector_width, scheduling, target); + tasks.push_back({[=]() { return mul(vector_width, scheduling, target); }}); + tasks.push_back({[=]() { return mul(vector_width, scheduling, target); }}); + tasks.push_back({[=]() { return mul(vector_width, scheduling, target); }}); + tasks.push_back({[=]() { return mul(vector_width, scheduling, target); }}); + tasks.push_back({[=]() { return mul(vector_width, scheduling, target); }}); + tasks.push_back({[=]() { return mul(vector_width, scheduling, target); }}); // Widening multiplication. - success &= mul(vector_width, scheduling, target); - success &= mul(vector_width, scheduling, target); - success &= mul(vector_width, scheduling, target); - success &= mul(vector_width, scheduling, target); + tasks.push_back({[=]() { return mul(vector_width, scheduling, target); }}); + tasks.push_back({[=]() { return mul(vector_width, scheduling, target); }}); + tasks.push_back({[=]() { return mul(vector_width, scheduling, target); }}); + tasks.push_back({[=]() { return mul(vector_width, scheduling, target); }}); // Mixed multiplication. This isn't all of the possible mixed // multiplications, but it covers all of the special cases we // have in Halide. - success &= mul(vector_width, scheduling, target); - success &= mul(vector_width, scheduling, target); - success &= mul(vector_width, scheduling, target); - - return success; + tasks.push_back({[=]() { return mul(vector_width, scheduling, target); }}); + tasks.push_back({[=]() { return mul(vector_width, scheduling, target); }}); + tasks.push_back({[=]() { return mul(vector_width, scheduling, target); }}); } -bool test_div_mod(int vector_width, ScheduleVariant scheduling, Target target) { - std::cout << "Testing div_mod vector_width: " + std::to_string(vector_width) + "\n"; - - bool success = true; - - success &= div_mod(vector_width, scheduling, target); - success &= div_mod(vector_width, scheduling, target); - success &= div_mod(vector_width, scheduling, target); - success &= div_mod(vector_width, scheduling, target); - success &= div_mod(vector_width, scheduling, target); - success &= div_mod(vector_width, scheduling, target); - return success; +void add_test_div_mod(int vector_width, ScheduleVariant scheduling, Target target, std::vector &tasks) { + tasks.push_back({[=]() { return div_mod(vector_width, scheduling, target); }}); + tasks.push_back({[=]() { return div_mod(vector_width, scheduling, target); }}); + tasks.push_back({[=]() { return div_mod(vector_width, scheduling, target); }}); + tasks.push_back({[=]() { return div_mod(vector_width, scheduling, target); }}); + tasks.push_back({[=]() { return div_mod(vector_width, scheduling, target); }}); + tasks.push_back({[=]() { return div_mod(vector_width, scheduling, target); }}); } int main(int argc, char **argv) { Target target = get_jit_target_from_environment(); - bool can_parallelize = !target.has_feature(Target::OpenGLCompute); - ScheduleVariant scheduling = CPU; if (target.has_gpu_feature()) { scheduling = TiledGPU; @@ -573,52 +564,24 @@ int main(int argc, char **argv) { } } - size_t num_threads = Halide::Internal::ThreadPool::num_processors_online(); - if (target.has_feature(Target::OpenCL)) { - // TODO(https://github.com/halide/Halide/issues/5634): - // Try to track down sporadic failures of this function for OpenCL - // -- avoid running simultaneous tests - // -- set HL_DEBUG_CODEGEN so we can see what the IR looks like - num_threads = 1; -#ifdef _WIN32 - _putenv_s("HL_DEBUG_CODEGEN", "1"); -#else - setenv("HL_DEBUG_CODEGEN", "1", 1); -#endif - } - - Halide::Internal::ThreadPool pool(num_threads); - std::vector> futures; - + std::vector tasks; for (int vector_width : vector_widths) { - if (can_parallelize) { - auto f = pool.async(test_mul, vector_width, scheduling, target); - futures.push_back(std::move(f)); - } else if (!test_mul(vector_width, scheduling, target)) { - return -1; - } + add_test_mul(vector_width, scheduling, target, tasks); } - for (int vector_width : vector_widths) { - if (can_parallelize) { - auto f = pool.async(test_div_mod, vector_width, scheduling, target); - futures.push_back(std::move(f)); - } else if (!test_div_mod(vector_width, scheduling, target)) { - return -1; - } + add_test_div_mod(vector_width, scheduling, target, tasks); } - futures.push_back(pool.async(f_mod)); - - bool success = true; - for (auto &f : futures) { - success &= f.get(); + using Sharder = Halide::Internal::Test::Sharder; + Sharder sharder; + for (size_t t = 0; t < tasks.size(); t++) { + if (!sharder.should_run(t)) continue; + const auto &task = tasks.at(t); + if (!task.fn()) { + exit(-1); + } } - if (!success) { - printf("Failure!\n"); - return -1; - } printf("Success!\n"); return 0; } diff --git a/test/correctness/multi_output_pipeline_with_bad_sizes.cpp b/test/correctness/multi_output_pipeline_with_bad_sizes.cpp index 60d1fbc175cd..9be693568bfa 100644 --- a/test/correctness/multi_output_pipeline_with_bad_sizes.cpp +++ b/test/correctness/multi_output_pipeline_with_bad_sizes.cpp @@ -21,7 +21,7 @@ int main(int argc, char **argv) { f.jit_handlers().custom_error = &halide_error; error_occurred = false; - Realization r(x_out, sin_x_out); + Realization r({x_out, sin_x_out}); f.realize(r); if (!error_occurred) { diff --git a/test/correctness/multiple_outputs_extern.cpp b/test/correctness/multiple_outputs_extern.cpp index d700a36654cb..a29dbd62c049 100644 --- a/test/correctness/multiple_outputs_extern.cpp +++ b/test/correctness/multiple_outputs_extern.cpp @@ -1,13 +1,7 @@ #include "Halide.h" #include -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - -extern "C" DLLEXPORT int flip_x_and_sum(halide_buffer_t *in1, halide_buffer_t *in2, halide_buffer_t *out) { +extern "C" HALIDE_EXPORT_SYMBOL int flip_x_and_sum(halide_buffer_t *in1, halide_buffer_t *in2, halide_buffer_t *out) { int min = out->dim[0].min; int max = out->dim[0].min + out->dim[0].extent - 1; diff --git a/test/correctness/non_nesting_extern_bounds_query.cpp b/test/correctness/non_nesting_extern_bounds_query.cpp index 8c8df8d4c506..6577612f3361 100644 --- a/test/correctness/non_nesting_extern_bounds_query.cpp +++ b/test/correctness/non_nesting_extern_bounds_query.cpp @@ -2,12 +2,6 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - // Extern stages are supposed to obey the following nesting property // on bounds queries: If some region of the output O requires some // region of the input I, then requesting any subset of O should only @@ -22,7 +16,7 @@ using namespace Halide; // received in non-bounds-query-mode is the intersection of what it // asked for for a single scanline and what it asked for for the whole // image. -extern "C" DLLEXPORT int misbehaving_extern_stage(halide_buffer_t *in, int variant, halide_buffer_t *out) { +extern "C" HALIDE_EXPORT_SYMBOL int misbehaving_extern_stage(halide_buffer_t *in, int variant, halide_buffer_t *out) { if (in->is_bounds_query()) { // As a baseline, require the same amount of input as output, like a copy memcpy(in->dim, out->dim, out->dimensions * sizeof(halide_dimension_t)); diff --git a/test/correctness/parallel_fork.cpp b/test/correctness/parallel_fork.cpp index e7b588da6f74..5183370072a4 100644 --- a/test/correctness/parallel_fork.cpp +++ b/test/correctness/parallel_fork.cpp @@ -11,13 +11,7 @@ using namespace Halide::Tools; std::atomic call_count{0}; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - -extern "C" DLLEXPORT int five_ms(int arg) { +extern "C" HALIDE_EXPORT_SYMBOL int five_ms(int arg) { call_count++; std::this_thread::sleep_for(std::chrono::milliseconds(5)); return arg; diff --git a/test/correctness/param.cpp b/test/correctness/param.cpp index 19bf94b43520..eb2399fbcffd 100644 --- a/test/correctness/param.cpp +++ b/test/correctness/param.cpp @@ -24,6 +24,7 @@ int main(int argc, char **argv) { } u.set(17); + u.set_estimate(17); Buffer out_17 = f.realize({1024}, target); // verify the get method. @@ -33,6 +34,7 @@ int main(int argc, char **argv) { // so setting the copy should be equivalent to setting the original. Param u_alias = u; u_alias.set(123); + u_alias.set_estimate(123); Buffer out_123 = f.realize({1024}, target); // verify the get method, again. @@ -69,6 +71,7 @@ int main(int argc, char **argv) { // For Param, you must provide an explicit template argument to set(), // and it must match the dynamic type of the Param. u.set(17); + u.set_estimate(17); Buffer out_17 = f.realize({1024}, target); // For Param, you must provide an explicit template argument to get(), @@ -82,6 +85,7 @@ int main(int argc, char **argv) { // so setting the copy should be equivalent to setting the original. Param<> u_alias = u; u_alias.set(123); + u_alias.set_estimate(123); Buffer out_123 = f.realize({1024}, target); assert(u.get() == 123); @@ -105,12 +109,14 @@ int main(int argc, char **argv) { f(x) = u; u.set(17); + u.set_estimate(17); Buffer out_17 = f.realize({1}); assert(out_17(0) == 17); // You can always construct a Param from a Param Param<> u_alias = u; u_alias.set(123); + u_alias.set_estimate(123); Buffer out_123 = f.realize({1}); assert(out_123(0) == 123); @@ -119,6 +125,7 @@ int main(int argc, char **argv) { // of the LHS (otherwise, assert-fails) Param u_alias2 = u_alias; u_alias2.set(124); + u_alias2.set_estimate(124); Buffer out_124 = f.realize({1}); assert(out_124(0) == 124); } @@ -131,6 +138,7 @@ int main(int argc, char **argv) { f(x) = u; u.set(17); + u.set_estimate(17); Buffer out_17 = f.realize({1}); assert(out_17(0) == 17); @@ -139,6 +147,7 @@ int main(int argc, char **argv) { u_alias = u; assert(u_alias.type() == Int(32)); u_alias.set(123); + u_alias.set_estimate(123); Buffer out_123 = f.realize({1}); assert(out_123(0) == 123); @@ -148,6 +157,7 @@ int main(int argc, char **argv) { Param u_alias2; u_alias2 = u_alias; u_alias2.set(124); + u_alias2.set_estimate(124); Buffer out_124 = f.realize({1}); assert(out_124(0) == 124); } diff --git a/test/correctness/pipeline_set_jit_externs_func.cpp b/test/correctness/pipeline_set_jit_externs_func.cpp index bec64c35d4c9..65a82705ac56 100644 --- a/test/correctness/pipeline_set_jit_externs_func.cpp +++ b/test/correctness/pipeline_set_jit_externs_func.cpp @@ -3,14 +3,8 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - int call_counter = 0; -extern "C" DLLEXPORT float my_func(int x, float y) { +extern "C" HALIDE_EXPORT_SYMBOL float my_func(int x, float y) { call_counter++; return x * y; } diff --git a/test/correctness/process_some_tiles.cpp b/test/correctness/process_some_tiles.cpp index dd7a80e77b50..1aeba1f22121 100644 --- a/test/correctness/process_some_tiles.cpp +++ b/test/correctness/process_some_tiles.cpp @@ -5,14 +5,8 @@ using namespace Halide; // A version of pow that tracks usage so we can check how many times it was called. -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - int call_count = 0; -extern "C" DLLEXPORT float my_powf(float x, float y) { +extern "C" HALIDE_EXPORT_SYMBOL float my_powf(float x, float y) { call_count++; // We have to read from call_count, or for some reason apple clang removes it entirely. assert(call_count != -1); diff --git a/test/correctness/realize_condition_depends_on_tuple.cpp b/test/correctness/realize_condition_depends_on_tuple.cpp new file mode 100644 index 000000000000..c4bf0b58e9a7 --- /dev/null +++ b/test/correctness/realize_condition_depends_on_tuple.cpp @@ -0,0 +1,31 @@ +#include "Halide.h" + +using namespace Halide; + +// This is a test for a bug where the condition on a realize node didn't have +// tuple-valued calls resolved if the realization was itself tuple-valued. + +int main(int argc, char **argv) { + Func f; + Param p; + f() = {p, p}; + + Func g; + g() = {4, 4}; + + Func h; + h() = g()[1]; + + // h may or may not be necessary to evaluate, depending on a load from f, + // which means g in turn may or may not be necessary to allocate. + Func out; + out() = select(f()[1] == 3, h(), 17); + + f.compute_root(); + g.compute_root(); + h.compute_root(); + out.compile_jit(); + + printf("Success!\n"); + return 0; +} diff --git a/test/correctness/reduction_predicate_racing.cpp b/test/correctness/reduction_predicate_racing.cpp new file mode 100644 index 000000000000..0b5dd94cdd18 --- /dev/null +++ b/test/correctness/reduction_predicate_racing.cpp @@ -0,0 +1,47 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Var x; + + { + Func f; + + RDom r(1, 10); + f(x) = 1; + // this does not race, because the RDom does not contain 0 + r.where(f(0) == 1); + f(r) = 2; + + f.update().parallel(r); + } + + { + Func f; + + RDom r(0, 10); + f(x) = 1; + // this does not race, because there is no communication + r.where(f(r) == 1); + f(r) = 2; + + f.update().parallel(r); + } + + { + Func f; + + RDom r(0, 10); + f(x) = 1; + // this does not race, because there is no communication (odds vs evens) + r.where(f(2 * r) == 1); + f(2 * r + 1) = 2; + + f.update().parallel(r); + } + + printf("Success!\n"); + return 0; +} diff --git a/test/correctness/rfactor.cpp b/test/correctness/rfactor.cpp index 0cbaace189f2..ccefd4275443 100644 --- a/test/correctness/rfactor.cpp +++ b/test/correctness/rfactor.cpp @@ -1,5 +1,6 @@ #include "Halide.h" #include "check_call_graphs.h" +#include "test_sharding.h" #include #include @@ -12,7 +13,8 @@ using std::string; using namespace Halide; using namespace Halide::Internal; -int simple_rfactor_test(bool compile_module) { +template +int simple_rfactor_test() { Func f("f"), g("g"); Var x("x"), y("y"); @@ -52,7 +54,8 @@ int simple_rfactor_test(bool compile_module) { return 0; } -int reorder_split_rfactor_test(bool compile_module) { +template +int reorder_split_rfactor_test() { Func f("f"), g("g"); Var x("x"), y("y"); @@ -97,7 +100,8 @@ int reorder_split_rfactor_test(bool compile_module) { return 0; } -int multi_split_rfactor_test(bool compile_module) { +template +int multi_split_rfactor_test() { Func f("f"), g("g"); Var x("x"), y("y"); @@ -145,7 +149,8 @@ int multi_split_rfactor_test(bool compile_module) { return 0; } -int reorder_fuse_wrapper_rfactor_test(bool compile_module) { +template +int reorder_fuse_wrapper_rfactor_test() { Func f("f"), g("g"); Var x("x"), y("y"), z("z"); @@ -195,7 +200,8 @@ int reorder_fuse_wrapper_rfactor_test(bool compile_module) { return 0; } -int non_trivial_lhs_rfactor_test(bool compile_module) { +template +int non_trivial_lhs_rfactor_test() { Func a("a"), b("b"), c("c"); Var x("x"), y("y"), z("z"); @@ -265,7 +271,8 @@ int non_trivial_lhs_rfactor_test(bool compile_module) { return 0; } -int simple_rfactor_with_specialize_test(bool compile_module) { +template +int simple_rfactor_with_specialize_test() { Func f("f"), g("g"); Var x("x"), y("y"); @@ -319,7 +326,8 @@ int simple_rfactor_with_specialize_test(bool compile_module) { return 0; } -int rdom_with_predicate_rfactor_test(bool compile_module) { +template +int rdom_with_predicate_rfactor_test() { Func f("f"), g("g"); Var x("x"), y("y"), z("z"); @@ -364,7 +372,8 @@ int rdom_with_predicate_rfactor_test(bool compile_module) { return 0; } -int histogram_rfactor_test(bool compile_module) { +template +int histogram_rfactor_test() { int W = 128, H = 128; // Compute a random image and its true histogram @@ -420,7 +429,8 @@ int histogram_rfactor_test(bool compile_module) { return 0; } -int parallel_dot_product_rfactor_test(bool compile_module) { +template +int parallel_dot_product_rfactor_test() { int size = 1024; Func f("f"), g("g"), a("a"), b("b"); @@ -482,7 +492,8 @@ int parallel_dot_product_rfactor_test(bool compile_module) { return 0; } -int tuple_rfactor_test(bool compile_module) { +template +int tuple_rfactor_test() { Func f("f"), g("g"); Var x("x"), y("y"); @@ -552,7 +563,8 @@ int tuple_rfactor_test(bool compile_module) { return 0; } -int tuple_specialize_rdom_predicate_rfactor_test(bool compile_module) { +template +int tuple_specialize_rdom_predicate_rfactor_test() { Func f("f"), g("g"); Var x("x"), y("y"), z("z"); @@ -887,7 +899,8 @@ int rfactor_tile_reorder_test() { return 0; } -int tuple_partial_reduction_rfactor_test(bool compile_module) { +template +int tuple_partial_reduction_rfactor_test() { Func f("f"), g("g"); Var x("x"), y("y"); @@ -982,160 +995,54 @@ int self_assignment_rfactor_test() { } // namespace int main(int argc, char **argv) { - printf("Running self assignment rfactor test\n"); - if (self_assignment_rfactor_test() != 0) { - return -1; - } - - printf("Running simple rfactor test\n"); - printf(" checking call graphs...\n"); - if (simple_rfactor_test(true) != 0) { - return -1; - } - printf(" checking output img correctness...\n"); - if (simple_rfactor_test(false) != 0) { - return -1; - } - - printf("Running reorder split rfactor test\n"); - printf(" checking call graphs...\n"); - if (reorder_split_rfactor_test(true) != 0) { - return -1; - } - printf(" checking output img correctness...\n"); - if (reorder_split_rfactor_test(false) != 0) { - return -1; - } - - printf("Running multiple split rfactor test\n"); - printf(" checking call graphs...\n"); - if (multi_split_rfactor_test(true) != 0) { - return -1; - } - printf(" checking output img correctness...\n"); - if (multi_split_rfactor_test(false) != 0) { - return -1; - } - - printf("Running reorder fuse wrapper rfactor test\n"); - printf(" checking call graphs...\n"); - if (reorder_fuse_wrapper_rfactor_test(true) != 0) { - return -1; - } - printf(" checking output img correctness...\n"); - if (reorder_fuse_wrapper_rfactor_test(false) != 0) { - return -1; - } - - printf("Running non trivial lhs rfactor test\n"); - printf(" checking call graphs...\n"); - if (non_trivial_lhs_rfactor_test(true) != 0) { - return -1; - } - printf(" checking output img correctness...\n"); - if (non_trivial_lhs_rfactor_test(false) != 0) { - return -1; - } - - printf("Running simple rfactor with specialization test\n"); - printf(" checking call graphs...\n"); - if (simple_rfactor_with_specialize_test(true) != 0) { - return -1; - } - printf(" checking output img correctness...\n"); - if (simple_rfactor_with_specialize_test(false) != 0) { - return -1; - } - - printf("Running rdom with predicate rfactor test\n"); - printf(" checking call graphs...\n"); - if (rdom_with_predicate_rfactor_test(true) != 0) { - return -1; - } - printf(" checking output img correctness...\n"); - if (rdom_with_predicate_rfactor_test(false) != 0) { - return -1; - } - - printf("Running histogram rfactor test\n"); - printf(" checking call graphs...\n"); - if (histogram_rfactor_test(true) != 0) { - return -1; - } - printf(" checking output img correctness...\n"); - if (histogram_rfactor_test(false) != 0) { - return -1; - } - - printf("Running parallel dot product rfactor test\n"); - printf(" checking call graphs...\n"); - if (parallel_dot_product_rfactor_test(true) != 0) { - return -1; - } - printf(" checking output img correctness...\n"); - if (parallel_dot_product_rfactor_test(false) != 0) { - return -1; - } - - printf("Running tuple rfactor test\n"); - printf(" checking call graphs...\n"); - if (tuple_rfactor_test(true) != 0) { - return -1; - } - printf(" checking output img correctness...\n"); - if (tuple_rfactor_test(false) != 0) { - return -1; - } - - printf("Running tuple specialize rdom predicate rfactor test\n"); - printf(" checking call graphs...\n"); - if (tuple_specialize_rdom_predicate_rfactor_test(true) != 0) { - return -1; - } - printf(" checking output img correctness...\n"); - if (tuple_specialize_rdom_predicate_rfactor_test(false) != 0) { - return -1; - } - - printf("Running parallel dot product rfactor test\n"); - printf(" checking call graphs...\n"); - if (parallel_dot_product_rfactor_test(true) != 0) { - return -1; - } - printf(" checking output img correctness...\n"); - if (parallel_dot_product_rfactor_test(false) != 0) { - return -1; - } - - printf("Running tuple partial reduction rfactor test\n"); - printf(" checking call graphs...\n"); - if (tuple_partial_reduction_rfactor_test(true) != 0) { - return -1; - } - printf(" checking output img correctness...\n"); - if (tuple_partial_reduction_rfactor_test(false) != 0) { - return -1; - } - - printf("Running check allocation bound test\n"); - if (check_allocation_bound_test() != 0) { - return -1; - } - - printf("Running rfactor tile reorder test\n"); - printf(" checking output img correctness...\n"); - if (rfactor_tile_reorder_test() != 0) { - return -1; - } + struct Task { + std::string desc; + std::function fn; + }; - printf("Running complex multiply rfactor test\n"); - if (complex_multiply_rfactor_test() != 0) { - return -1; - } + std::vector tasks = { + {"self assignment rfactor test", self_assignment_rfactor_test}, + {"simple rfactor test: checking call graphs...", simple_rfactor_test}, + {"simple rfactor test: checking output img correctness...", simple_rfactor_test}, + {"reorder split rfactor test: checking call graphs...", reorder_split_rfactor_test}, + {"reorder split rfactor test: checking output img correctness...", reorder_split_rfactor_test}, + {"multiple split rfactor test: checking call graphs...", multi_split_rfactor_test}, + {"multiple split rfactor test: checking output img correctness...", multi_split_rfactor_test}, + {"reorder fuse wrapper rfactor test: checking call graphs...", reorder_fuse_wrapper_rfactor_test}, + {"reorder fuse wrapper rfactor test: checking output img correctness...", reorder_fuse_wrapper_rfactor_test}, + {"non trivial lhs rfactor test: checking call graphs...", non_trivial_lhs_rfactor_test}, + {"non trivial lhs rfactor test: checking output img correctness...", non_trivial_lhs_rfactor_test}, + {"simple rfactor with specialization test: checking call graphs...", simple_rfactor_with_specialize_test}, + {"simple rfactor with specialization test: checking output img correctness...", simple_rfactor_with_specialize_test}, + {"rdom with predicate rfactor test: checking call graphs...", rdom_with_predicate_rfactor_test}, + {"rdom with predicate rfactor test: checking output img correctness...", rdom_with_predicate_rfactor_test}, + {"histogram rfactor test: checking call graphs...", histogram_rfactor_test}, + {"histogram rfactor test: checking output img correctness...", histogram_rfactor_test}, + {"parallel dot product rfactor test: checking call graphs...", parallel_dot_product_rfactor_test}, + {"parallel dot product rfactor test: checking output img correctness...", parallel_dot_product_rfactor_test}, + {"tuple rfactor test: checking call graphs...", tuple_rfactor_test}, + {"tuple rfactor test: checking output img correctness...", tuple_rfactor_test}, + {"tuple specialize rdom predicate rfactor test: checking call graphs...", tuple_specialize_rdom_predicate_rfactor_test}, + {"tuple specialize rdom predicate rfactor test: checking output img correctness...", tuple_specialize_rdom_predicate_rfactor_test}, + {"parallel dot product rfactor test: checking call graphs...", parallel_dot_product_rfactor_test}, + {"parallel dot product rfactor test: checking output img correctness...", parallel_dot_product_rfactor_test}, + {"tuple partial reduction rfactor test: checking call graphs...", tuple_partial_reduction_rfactor_test}, + {"tuple partial reduction rfactor test: checking output img correctness...", tuple_partial_reduction_rfactor_test}, + {"check allocation bound test", check_allocation_bound_test}, + {"rfactor tile reorder test: checking output img correctness...", rfactor_tile_reorder_test}, + {"complex multiply rfactor test", complex_multiply_rfactor_test}, + {"argmin rfactor test", argmin_rfactor_test}, + }; - printf("Running argmin rfactor test\n"); - if (argmin_rfactor_test() != 0) { - return -1; + using Sharder = Halide::Internal::Test::Sharder; + Sharder sharder; + for (size_t t = 0; t < tasks.size(); t++) { + if (!sharder.should_run(t)) continue; + const auto &task = tasks.at(t); + std::cout << task.desc << "\n"; + if (task.fn() != 0) { + return -1; + } } printf("Success!\n"); diff --git a/test/correctness/shift_by_unsigned_negated.cpp b/test/correctness/shift_by_unsigned_negated.cpp new file mode 100644 index 000000000000..a8ca8dca8e06 --- /dev/null +++ b/test/correctness/shift_by_unsigned_negated.cpp @@ -0,0 +1,46 @@ +#include "Halide.h" + +using namespace Halide; + +template +bool test(Func f, T f_expected, int width) { + Buffer actual = f.realize({width}); + for (int i = 0; i < actual.width(); i++) { + if (actual(i) != f_expected(i)) { + printf("r(%d) = %d, f_expected(%d) = %d\n", + i, actual(i), i, f_expected(i)); + return false; + } + } + return true; +} + +int main(int argc, char **argv) { + Buffer step(31); + for (int i = 0; i < step.width(); i++) { + step(i) = -i; + } + + bool success = true; + Var x; + + { + Func f; + f(x) = Expr(-1U) << -step(x); + auto f_expected = [&](int x) { + return -1U << x; + }; + success &= test(f, f_expected, step.width()); + } + { + Func f; + f(x) = Expr(-1U) >> -step(x); + auto f_expected = [&](int x) { + return -1U >> x; + }; + success &= test(f, f_expected, step.width()); + } + + if (success) printf("Success!\n"); + return success ? 0 : -1; +} diff --git a/test/correctness/side_effects.cpp b/test/correctness/side_effects.cpp index 5620a5d6f380..adc90173dd59 100644 --- a/test/correctness/side_effects.cpp +++ b/test/correctness/side_effects.cpp @@ -15,14 +15,8 @@ using namespace Halide; // thread-safe. // Here we use an extern call to print an ascii-art Mandelbrot set. -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - int call_count = 0; -extern "C" DLLEXPORT int draw_pixel(int x, int y, int val) { +extern "C" HALIDE_EXPORT_SYMBOL int draw_pixel(int x, int y, int val) { call_count++; static int last_y = 0; if (y != last_y) { diff --git a/test/correctness/simd_op_check.cpp b/test/correctness/simd_op_check.cpp index a6760f6a290b..ba3e24d6e80c 100644 --- a/test/correctness/simd_op_check.cpp +++ b/test/correctness/simd_op_check.cpp @@ -158,9 +158,14 @@ class SimdOpCheck : public SimdOpCheckTest { check("pavgw", 4 * w, u16((u32(u16_1) + u32(u16_2) + 1) / 2)); check("pavgw", 4 * w, u16((u32(u16_1) + u32(u16_2) + 1) >> 1)); - // Rounding right shifts should also use pavg + // Rounding right shifts, halving subtracts, and signed rounding + // averages should also use pavg check("pavgb", 8 * w, rounding_shift_right(u8_1, 2)); check("pavgw", 4 * w, rounding_shift_right(u16_1, 2)); + check("pavgb", 8 * w, halving_sub(u8_1, u8_2)); + check("pavgw", 4 * w, halving_sub(u16_1, u16_2)); + check("pavgb", 8 * w, rounding_halving_add(i8_1, i8_2)); + check("pavgw", 4 * w, rounding_halving_add(i16_1, i16_2)); check("pmaxsw", 4 * w, max(i16_1, i16_2)); check("pminsw", 4 * w, min(i16_1, i16_2)); @@ -237,6 +242,18 @@ class SimdOpCheck : public SimdOpCheckTest { check(std::string("packssdw") + check_suffix, 4 * w, i16_sat(i32_1)); check(std::string("packsswb") + check_suffix, 8 * w, i8_sat(i16_1)); check(std::string("packuswb") + check_suffix, 8 * w, u8_sat(i16_1)); + + // Sum-of-absolute-difference ops + { + const int f = 8; // reduction factor. + RDom r(0, f); + check("psadbw", w, sum(u64(absd(in_u8(f * x + r), in_u8(f * x + r + 32))))); + check("psadbw", w, sum(u32(absd(in_u8(f * x + r), in_u8(f * x + r + 32))))); + check("psadbw", w, sum(u16(absd(in_u8(f * x + r), in_u8(f * x + r + 32))))); + check("psadbw", w, sum(i64(absd(in_u8(f * x + r), in_u8(f * x + r + 32))))); + check("psadbw", w, sum(i32(absd(in_u8(f * x + r), in_u8(f * x + r + 32))))); + check("psadbw", w, sum(i16(absd(in_u8(f * x + r), in_u8(f * x + r + 32))))); + } } // SSE 3 / SSSE 3 @@ -303,6 +320,11 @@ class SimdOpCheck : public SimdOpCheckTest { RDom r2(0, 2); check(check_pmaddubsw, 4 * w, saturating_sum(i16(in_u8(2 * x + r2)) * in_i8(2 * x + r2 + 32))); check(check_pmaddubsw, 4 * w, saturating_sum(i16(in_i8(2 * x + r2)) * in_u8(2 * x + r2 + 32))); + + // uint8 -> uint16 or int16 and int8 -> int16 horizontal widening adds should use pmaddubsw. + check(check_pmaddubsw, 4 * w, sum(u16(in_u8(2 * x + r2)))); + check(check_pmaddubsw, 4 * w, sum(i16(in_u8(2 * x + r2)))); + check(check_pmaddubsw, 4 * w, sum(i16(in_i8(2 * x + r2)))); } } @@ -318,6 +340,9 @@ class SimdOpCheck : public SimdOpCheckTest { // And also for dot-products RDom r4(0, 4); check(check_pmaddwd, 2 * w, sum(i32(in_i16(x * 4 + r4)) * in_i16(x * 4 + r4 + 32))); + + // Also generate for widening_mul + check(check_pmaddwd, 2 * w, i32(i16_1) * i32(i16_2)); } // llvm doesn't distinguish between signed and unsigned multiplies @@ -505,6 +530,18 @@ class SimdOpCheck : public SimdOpCheckTest { check("vpcmpeqq*ymm", 4, select(i64_1 == i64_2, i64(1), i64(2))); check("vpackusdw*ymm", 16, u16(clamp(i32_1, 0, max_u16))); check("vpcmpgtq*ymm", 4, select(i64_1 > i64_2, i64(1), i64(2))); + + // Sum-of-absolute-difference ops + for (int w : {4, 8}) { + const int f = 8; // reduction factor. + RDom r(0, f); + check("vpsadbw", w, sum(u64(absd(in_u8(f * x + r), in_u8(f * x + r + 32))))); + check("vpsadbw", w, sum(u32(absd(in_u8(f * x + r), in_u8(f * x + r + 32))))); + check("vpsadbw", w, sum(u16(absd(in_u8(f * x + r), in_u8(f * x + r + 32))))); + check("vpsadbw", w, sum(i64(absd(in_u8(f * x + r), in_u8(f * x + r + 32))))); + check("vpsadbw", w, sum(i32(absd(in_u8(f * x + r), in_u8(f * x + r + 32))))); + check("vpsadbw", w, sum(i16(absd(in_u8(f * x + r), in_u8(f * x + r + 32))))); + } } if (use_avx512) { @@ -1436,12 +1473,22 @@ class SimdOpCheck : public SimdOpCheckTest { check(arm32 ? "vshr.u64" : "ushr", 2 * w, u64_1 / 16); // VSHRN I - Shift Right Narrow - check(arm32 ? "vshrn.i16" : "shrn", 8 * w, i8(i16_1 / 256)); - check(arm32 ? "vshrn.i32" : "shrn", 4 * w, i16(i32_1 / 65536)); - check(arm32 ? "vshrn.i64" : "shrn", 2 * w, i32(i64_1 >> 32)); - check(arm32 ? "vshrn.i16" : "shrn", 8 * w, u8(u16_1 / 256)); - check(arm32 ? "vshrn.i32" : "shrn", 4 * w, u16(u32_1 / 65536)); - check(arm32 ? "vshrn.i64" : "shrn", 2 * w, u32(u64_1 >> 32)); + // LLVM15 emits UZP2 if the shift amount is half the width of the vector element. + const auto shrn_or_uzp2 = [&](int element_width, int shift_amt, int vector_width) { + constexpr int simd_vector_bits = 128; + if (Halide::Internal::get_llvm_version() >= 150 && + ((vector_width * element_width) % (simd_vector_bits * 2)) == 0 && + shift_amt == element_width / 2) { + return "uzp2"; + } + return "shrn"; + }; + check(arm32 ? "vshrn.i16" : shrn_or_uzp2(16, 8, 8 * w), 8 * w, i8(i16_1 / 256)); + check(arm32 ? "vshrn.i32" : shrn_or_uzp2(32, 16, 4 * w), 4 * w, i16(i32_1 / 65536)); + check(arm32 ? "vshrn.i64" : shrn_or_uzp2(64, 32, 2 * w), 2 * w, i32(i64_1 >> 32)); + check(arm32 ? "vshrn.i16" : shrn_or_uzp2(16, 8, 8 * w), 8 * w, u8(u16_1 / 256)); + check(arm32 ? "vshrn.i32" : shrn_or_uzp2(32, 16, 4 * w), 4 * w, u16(u32_1 / 65536)); + check(arm32 ? "vshrn.i64" : shrn_or_uzp2(64, 32, 2 * w), 2 * w, u32(u64_1 >> 32)); check(arm32 ? "vshrn.i16" : "shrn", 8 * w, i8(i16_1 / 16)); check(arm32 ? "vshrn.i32" : "shrn", 4 * w, i16(i32_1 / 16)); check(arm32 ? "vshrn.i64" : "shrn", 2 * w, i32(i64_1 / 16)); @@ -1766,27 +1813,17 @@ class SimdOpCheck : public SimdOpCheckTest { if (use_wasm_simd128) { for (int w = 1; w <= 4; w <<= 1) { // create arbitrary 16-byte constant - if (Halide::Internal::get_llvm_version() >= 130) { - check("v128.const", 16 * w, u8_1 * u8(42 + x)); - } + check("v128.const", 16 * w, u8_1 * u8(42 + x)); // Create vector with identical lanes // (Note that later LLVMs will use 64-bit constants for some smaller splats) check("i8x16.splat", 16 * w, u8_1 * u8(42)); - if (Halide::Internal::get_llvm_version() >= 130) { - // LLVM13 likes to emit all of these as v128.const - check("v128.const", 8 * w, u16_1 * u16(42)); - check("v128.const", 4 * w, u32_1 * u32(42)); - check("v128.const", 2 * w, u64_1 * u64(42)); - check("v128.const", 8 * w, f32_1 * f32(42)); - check("v128.const", 4 * w, f64_1 * f64(42)); - } else { - check("i64x2.splat", 8 * w, u16_1 * u16(42)); - check("i64x2.splat", 4 * w, u32_1 * u32(42)); - check("i64x2.splat", 2 * w, u64_1 * u64(42)); - check("f32x4.splat", 8 * w, f32_1 * f32(42)); - check("f64x2.splat", 4 * w, f64_1 * f64(42)); - } + // LLVM13 likes to emit all of these as v128.const + check("v128.const", 8 * w, u16_1 * u16(42)); + check("v128.const", 4 * w, u32_1 * u32(42)); + check("v128.const", 2 * w, u64_1 * u64(42)); + check("v128.const", 8 * w, f32_1 * f32(42)); + check("v128.const", 4 * w, f64_1 * f64(42)); // Extract lane as a scalar (extract_lane) // Replace lane value (replace_lane) @@ -1826,13 +1863,11 @@ class SimdOpCheck : public SimdOpCheckTest { check("i32x4.mul", 4 * w, i32_1 * i32_2); check("i64x2.mul", 2 * w, i64_1 * i64_2); - if (Halide::Internal::get_llvm_version() >= 130) { - // Integer dot product (16 -> 32) - for (int f : {2, 4, 8}) { - RDom r(0, f); - for (int v : {1, 2, 4}) { - check("i32x4.dot_i16x8_s", w * v, sum(i32(in_i16(f * x + r)) * in_i16(f * x + r + 32))); - } + // Integer dot product (16 -> 32) + for (int f : {2, 4, 8}) { + RDom r(0, f); + for (int v : {1, 2, 4}) { + check("i32x4.dot_i16x8_s", w * v, sum(i32(in_i16(f * x + r)) * in_i16(f * x + r + 32))); } } @@ -1842,72 +1877,65 @@ class SimdOpCheck : public SimdOpCheckTest { check("i32x4.neg", 4 * w, -i32_1); check("i64x2.neg", 2 * w, -i64_1); - if (Halide::Internal::get_llvm_version() >= 130) { - // At present, we only attempt to generate these for LLVM >= 13. - - // Extended (widening) integer multiplication - if (w > 1) { - // Need a register wider than 128 bits for us to generate these - check("i16x8.extmul_low_i8x16_s", 8 * w, i16(i8_1) * i8_2); - check("i32x4.extmul_low_i16x8_s", 4 * w, i32(i16_1) * i16_2); - check("i64x2.extmul_low_i32x4_s", 2 * w, i64(i32_1) * i32_2); - check("i16x8.extmul_low_i8x16_u", 8 * w, u16(u8_1) * u8_2); - check("i32x4.extmul_low_i16x8_u", 4 * w, u32(u16_1) * u16_2); - check("i64x2.extmul_low_i32x4_u", 2 * w, u64(u32_1) * u32_2); - check("i16x8.extmul_high_i8x16_s", 8 * w, i16(i8_1) * i8_2); - check("i32x4.extmul_high_i16x8_s", 4 * w, i32(i16_1) * i16_2); - check("i64x2.extmul_high_i32x4_s", 2 * w, i64(i32_1) * i32_2); - check("i16x8.extmul_high_i8x16_u", 8 * w, u16(u8_1) * u8_2); - check("i32x4.extmul_high_i16x8_u", 4 * w, u32(u16_1) * u16_2); - check("i64x2.extmul_high_i32x4_u", 2 * w, u64(u32_1) * u32_2); - } + // Extended (widening) integer multiplication + if (w > 1) { + // Need a register wider than 128 bits for us to generate these + check("i16x8.extmul_low_i8x16_s", 8 * w, i16(i8_1) * i8_2); + check("i32x4.extmul_low_i16x8_s", 4 * w, i32(i16_1) * i16_2); + check("i64x2.extmul_low_i32x4_s", 2 * w, i64(i32_1) * i32_2); + check("i16x8.extmul_low_i8x16_u", 8 * w, u16(u8_1) * u8_2); + check("i32x4.extmul_low_i16x8_u", 4 * w, u32(u16_1) * u16_2); + check("i64x2.extmul_low_i32x4_u", 2 * w, u64(u32_1) * u32_2); + check("i16x8.extmul_high_i8x16_s", 8 * w, i16(i8_1) * i8_2); + check("i32x4.extmul_high_i16x8_s", 4 * w, i32(i16_1) * i16_2); + check("i64x2.extmul_high_i32x4_s", 2 * w, i64(i32_1) * i32_2); + check("i16x8.extmul_high_i8x16_u", 8 * w, u16(u8_1) * u8_2); + check("i32x4.extmul_high_i16x8_u", 4 * w, u32(u16_1) * u16_2); + check("i64x2.extmul_high_i32x4_u", 2 * w, u64(u32_1) * u32_2); + } - // Extended pairwise integer addition - for (int f : {2, 4}) { - RDom r(0, f); + // Extended pairwise integer addition + for (int f : {2, 4}) { + RDom r(0, f); - // A summation reduction that starts at something - // non-trivial, to avoid llvm simplifying accumulating - // widening summations into just widening summations. - auto sum_ = [&](Expr e) { - Func f; - f(x) = cast(e.type(), 123); - f(x) += e; - return f(x); - }; - - check("i16x8.extadd_pairwise_i8x16_s", 8 * w, sum_(i16(in_i8(f * x + r)))); - check("i16x8.extadd_pairwise_i8x16_u", 8 * w, sum_(u16(in_u8(f * x + r)))); - // The u8->i16 op uses the unsigned variant - check("i16x8.extadd_pairwise_i8x16_u", 8 * w, sum_(i16(in_u8(f * x + r)))); - - check("i32x4.extadd_pairwise_i16x8_s", 8 * w, sum_(i32(in_i16(f * x + r)))); - check("i32x4.extadd_pairwise_i16x8_u", 8 * w, sum_(u32(in_u16(f * x + r)))); - // The u16->i32 op uses the unsigned variant - check("i32x4.extadd_pairwise_i16x8_u", 8 * w, sum_(i32(in_u16(f * x + r)))); - } + // A summation reduction that starts at something + // non-trivial, to avoid llvm simplifying accumulating + // widening summations into just widening summations. + auto sum_ = [&](Expr e) { + Func f; + f(x) = cast(e.type(), 123); + f(x) += e; + return f(x); + }; + + check("i16x8.extadd_pairwise_i8x16_s", 8 * w, sum_(i16(in_i8(f * x + r)))); + check("i16x8.extadd_pairwise_i8x16_u", 8 * w, sum_(u16(in_u8(f * x + r)))); + // The u8->i16 op uses the unsigned variant + check("i16x8.extadd_pairwise_i8x16_u", 8 * w, sum_(i16(in_u8(f * x + r)))); + + check("i32x4.extadd_pairwise_i16x8_s", 8 * w, sum_(i32(in_i16(f * x + r)))); + check("i32x4.extadd_pairwise_i16x8_u", 8 * w, sum_(u32(in_u16(f * x + r)))); + // The u16->i32 op uses the unsigned variant + check("i32x4.extadd_pairwise_i16x8_u", 8 * w, sum_(i32(in_u16(f * x + r)))); } // Saturating integer addition - std::string sat = Halide::Internal::get_llvm_version() >= 130 ? "sat" : "saturate"; - check("i8x16.add_" + sat + "_s", 16 * w, i8_sat(i16(i8_1) + i16(i8_2))); - check("i8x16.add_" + sat + "_u", 16 * w, u8_sat(u16(u8_1) + u16(u8_2))); - check("i16x8.add_" + sat + "_s", 8 * w, i16_sat(i32(i16_1) + i32(i16_2))); - check("i16x8.add_" + sat + "_u", 8 * w, u16_sat(u32(u16_1) + u32(u16_2))); + check("i8x16.add_sat_s", 16 * w, i8_sat(i16(i8_1) + i16(i8_2))); + check("i8x16.add_sat_u", 16 * w, u8_sat(u16(u8_1) + u16(u8_2))); + check("i16x8.add_sat_s", 8 * w, i16_sat(i32(i16_1) + i32(i16_2))); + check("i16x8.add_sat_u", 8 * w, u16_sat(u32(u16_1) + u32(u16_2))); // Saturating integer subtraction - check("i8x16.sub_" + sat + "_s", 16 * w, i8_sat(i16(i8_1) - i16(i8_2))); - check("i16x8.sub_" + sat + "_s", 8 * w, i16_sat(i32(i16_1) - i32(i16_2))); + check("i8x16.sub_sat_s", 16 * w, i8_sat(i16(i8_1) - i16(i8_2))); + check("i16x8.sub_sat_s", 8 * w, i16_sat(i32(i16_1) - i32(i16_2))); // N.B. Saturating subtracts are expressed by widening to a *signed* type - check("i8x16.sub_" + sat + "_u", 16 * w, u8_sat(i16(u8_1) - i16(u8_2))); - check("i16x8.sub_" + sat + "_u", 8 * w, u16_sat(i32(u16_1) - i32(u16_2))); - - if (Halide::Internal::get_llvm_version() >= 130) { - // Saturating integer Q-format rounding multiplication - // Note: division in Halide always rounds down (not towards - // zero). Otherwise these patterns would be more complicated. - check("i16x8.q15mulr_sat_s", 8 * w, i16_sat((i32(i16_1) * i32(i16_2) + (1 << 14)) / (1 << 15))); - } + check("i8x16.sub_sat_u", 16 * w, u8_sat(i16(u8_1) - i16(u8_2))); + check("i16x8.sub_sat_u", 8 * w, u16_sat(i32(u16_1) - i32(u16_2))); + + // Saturating integer Q-format rounding multiplication + // Note: division in Halide always rounds down (not towards + // zero). Otherwise these patterns would be more complicated. + check("i16x8.q15mulr_sat_s", 8 * w, i16_sat((i32(i16_1) * i32(i16_2) + (1 << 14)) / (1 << 15))); // Lane-wise integer minimum check("i8x16.min_s", 16 * w, min(i8_1, i8_2)); @@ -1935,9 +1963,7 @@ class SimdOpCheck : public SimdOpCheckTest { check("i8x16.abs", 16 * w, abs(i8_1)); check("i16x8.abs", 8 * w, abs(i16_1)); check("i32x4.abs", 4 * w, abs(i32_1)); - if (Halide::Internal::get_llvm_version() >= 130) { - check("i64x2.abs", 2 * w, abs(i64_1)); - } + check("i64x2.abs", 2 * w, abs(i64_1)); // Left shift by constant scalar check("i8x16.shl", 16 * w, i8_1 << i8(7)); @@ -2046,9 +2072,7 @@ class SimdOpCheck : public SimdOpCheckTest { check("i8x16.eq", 16 * w, i8_1 == i8_2); check("i16x8.eq", 8 * w, i16_1 == i16_2); check("i32x4.eq", 4 * w, i32_1 == i32_2); - if (Halide::Internal::get_llvm_version() >= 130) { - check("i64x2.eq", 2 * w, i64_1 == i64_2); - } + check("i64x2.eq", 2 * w, i64_1 == i64_2); check("f32x4.eq", 4 * w, f32_1 == f32_2); check("f64x2.eq", 2 * w, f64_1 == f64_2); @@ -2056,9 +2080,7 @@ class SimdOpCheck : public SimdOpCheckTest { check("i8x16.ne", 16 * w, i8_1 != i8_2); check("i16x8.ne", 8 * w, i16_1 != i16_2); check("i32x4.ne", 4 * w, i32_1 != i32_2); - if (Halide::Internal::get_llvm_version() >= 130) { - check("i64x2.ne", 2 * w, i64_1 != i64_2); - } + check("i64x2.ne", 2 * w, i64_1 != i64_2); check("f32x4.ne", 4 * w, f32_1 != f32_2); check("f64x2.ne", 2 * w, f64_1 != f64_2); @@ -2069,9 +2091,7 @@ class SimdOpCheck : public SimdOpCheckTest { check("i16x8.lt_u", 8 * w, u16_1 < u16_2); check("i32x4.lt_s", 4 * w, i32_1 < i32_2); check("i32x4.lt_u", 4 * w, u32_1 < u32_2); - if (Halide::Internal::get_llvm_version() >= 130) { - check("i64x2.lt_s", 2 * w, i64_1 < i64_2); - } + check("i64x2.lt_s", 2 * w, i64_1 < i64_2); check("f32x4.lt", 4 * w, f32_1 < f32_2); check("f64x2.lt", 2 * w, f64_1 < f64_2); @@ -2082,9 +2102,7 @@ class SimdOpCheck : public SimdOpCheckTest { check("i16x8.le_u", 8 * w, u16_1 <= u16_2); check("i32x4.le_s", 4 * w, i32_1 <= i32_2); check("i32x4.le_u", 4 * w, u32_1 <= u32_2); - if (Halide::Internal::get_llvm_version() >= 130) { - check("i64x2.le_s", 2 * w, i64_1 <= i64_2); - } + check("i64x2.le_s", 2 * w, i64_1 <= i64_2); check("f32x4.le", 4 * w, f32_1 <= f32_2); check("f64x2.le", 2 * w, f64_1 <= f64_2); @@ -2175,23 +2193,21 @@ class SimdOpCheck : public SimdOpCheckTest { check("f32x4.sqrt", 4 * w, sqrt(f32_1)); check("f64x2.sqrt", 2 * w, sqrt(f64_1)); - if (Halide::Internal::get_llvm_version() >= 130) { - // Round to integer above (ceiling) - check("f32x4.ceil", 4 * w, ceil(f32_1)); - check("f64x2.ceil", 2 * w, ceil(f64_1)); + // Round to integer above (ceiling) + check("f32x4.ceil", 4 * w, ceil(f32_1)); + check("f64x2.ceil", 2 * w, ceil(f64_1)); - // Round to integer below (floor) - check("f32x4.floor", 4 * w, floor(f32_1)); - check("f64x2.floor", 2 * w, floor(f64_1)); + // Round to integer below (floor) + check("f32x4.floor", 4 * w, floor(f32_1)); + check("f64x2.floor", 2 * w, floor(f64_1)); - // Round to integer toward zero (truncate to integer) - check("f32x4.trunc", 4 * w, trunc(f32_1)); - check("f64x2.trunc", 2 * w, trunc(f64_1)); + // Round to integer toward zero (truncate to integer) + check("f32x4.trunc", 4 * w, trunc(f32_1)); + check("f64x2.trunc", 2 * w, trunc(f64_1)); - // Round to nearest integer, ties to even) - check("f32x4.nearest", 4 * w, round(f32_1)); - check("f64x2.nearest", 2 * w, round(f64_1)); - } + // Round to nearest integer, ties to even) + check("f32x4.nearest", 4 * w, round(f32_1)); + check("f64x2.nearest", 2 * w, round(f64_1)); // Integer to single-precision floating point check("f32x4.convert_i32x4_s", 8 * w, cast(i32_1)); @@ -2222,17 +2238,15 @@ class SimdOpCheck : public SimdOpCheckTest { if (w < 2) { check("f64x2.promote_low_f32x4", 2 * w, cast(f32_1)); } - } else if (Halide::Internal::get_llvm_version() >= 130) { + } else { check("f64x2.promote_low_f32x4", 2 * w, cast(f32_1)); } // Integer to integer narrowing - if (Halide::Internal::get_llvm_version() >= 130) { - check("i8x16.narrow_i16x8_s", 16 * w, i8_sat(i16_1)); - check("i8x16.narrow_i16x8_u", 16 * w, u8_sat(i16_1)); - check("i16x8.narrow_i32x4_s", 8 * w, i16_sat(i32_1)); - check("i16x8.narrow_i32x4_u", 8 * w, u16_sat(i32_1)); - } + check("i8x16.narrow_i16x8_s", 16 * w, i8_sat(i16_1)); + check("i8x16.narrow_i16x8_u", 16 * w, u8_sat(i16_1)); + check("i16x8.narrow_i32x4_s", 8 * w, i16_sat(i32_1)); + check("i16x8.narrow_i32x4_u", 8 * w, u16_sat(i32_1)); // Integer to integer widening check("i16x8.extend_low_i8x16_s", 16 * w, i16(i8_1)); @@ -2277,7 +2291,6 @@ int main(int argc, char **argv) { if (argc > 1) { test.filter = argv[1]; - test.set_num_threads(1); } if (getenv("HL_SIMD_OP_CHECK_FILTER")) { @@ -2288,20 +2301,6 @@ int main(int argc, char **argv) { std::cout << "simd_op_check test seed: " << seed << "\n"; test.set_seed(seed); - // TODO: multithreading here is the cause of https://github.com/halide/Halide/issues/3669; - // the fundamental issue is that we make one set of ImageParams to construct many - // Exprs, then realize those Exprs on arbitrary threads; it is known that sharing - // one Func across multiple threads is not guaranteed to be safe, and indeed, TSAN - // reports data races, of which some are likely 'benign' (e.g. Function.freeze) but others - // are highly suspect (e.g. Function.lock_loop_levels). Since multithreading here - // was added just to avoid having this test be the last to finish, the expedient 'fix' - // for now is to remove the multithreading. A proper fix could be made by restructuring this - // test so that every Expr constructed for testing was guaranteed to share no Funcs - // (Function.deep_copy() perhaps). Of course, it would also be desirable to allow Funcs, Exprs, etc - // to be usable across multiple threads, but that is a major undertaking that is - // definitely not worthwhile for present Halide usage patterns. - test.set_num_threads(1); - if (argc > 2) { // Don't forget: if you want to run the standard tests to a specific output // directory, you'll need to invoke with the first arg enclosed diff --git a/test/correctness/simd_op_check.h b/test/correctness/simd_op_check.h index 5ad997542d53..f4ac0712ec29 100644 --- a/test/correctness/simd_op_check.h +++ b/test/correctness/simd_op_check.h @@ -3,6 +3,7 @@ #include "Halide.h" #include "halide_test_dirs.h" +#include "test_sharding.h" #include @@ -46,14 +47,14 @@ class SimdOpCheckTest { int W; int H; + using Sharder = Halide::Internal::Test::Sharder; + SimdOpCheckTest(const Target t, int w, int h) : target(t), W(w), H(h) { target = target .with_feature(Target::NoBoundsQuery) .with_feature(Target::NoAsserts) - .with_feature(Target::NoRuntime) - .with_feature(Target::DisableLLVMLoopOpt); - num_threads = Internal::ThreadPool::num_processors_online(); + .with_feature(Target::NoRuntime); } virtual ~SimdOpCheckTest() = default; @@ -61,14 +62,6 @@ class SimdOpCheckTest { rng.seed(seed); } - size_t get_num_threads() const { - return num_threads; - } - - void set_num_threads(size_t n) { - num_threads = n; - } - virtual bool can_run_code() const { // Assume we are configured to run wasm if requested // (we'll fail further downstream if not) @@ -321,17 +314,13 @@ class SimdOpCheckTest { virtual bool test_all() { /* First add some tests based on the target */ add_tests(); - Internal::ThreadPool pool(num_threads); - std::vector> futures; - for (const Task &task : tasks) { - futures.push_back(pool.async([this, task]() { - return check_one(task.op, task.name, task.vector_width, task.expr); - })); - } + Sharder sharder; bool success = true; - for (auto &f : futures) { - const TestResult &result = f.get(); + for (size_t t = 0; t < tasks.size(); t++) { + if (!sharder.should_run(t)) continue; + const auto &task = tasks.at(t); + auto result = check_one(task.op, task.name, task.vector_width, task.expr); std::cout << result.op << "\n"; if (!result.error_msg.empty()) { std::cerr << result.error_msg; @@ -343,8 +332,9 @@ class SimdOpCheckTest { } private: - size_t num_threads; const Halide::Var x{"x"}, y{"y"}; }; + } // namespace Halide + #endif // SIMD_OP_CHECK_H diff --git a/test/correctness/simd_op_check_hvx.cpp b/test/correctness/simd_op_check_hvx.cpp index 8b1b42e94eda..fc35840bcaad 100644 --- a/test/correctness/simd_op_check_hvx.cpp +++ b/test/correctness/simd_op_check_hvx.cpp @@ -290,6 +290,8 @@ class SimdOpCheckHVX : public SimdOpCheckTest { check("vdelta(v*,v*)", hvx_width / 2, in_u32(3 * x / 2)); check("vdelta(v*,v*)", hvx_width * 3, in_u16(x * 3)); check("vdelta(v*,v*)", hvx_width * 3, in_u8(x * 3)); + check("vdelta(v*,v*)", hvx_width * 4, in_u16(x * 4)); + check("vdelta(v*,v*)", hvx_width * 4, in_u8(x * 4)); check("vlut32(v*.b,v*.b,r*)", hvx_width / 1, in_u8(u8_1)); check("vlut32(v*.b,v*.b,r*)", hvx_width / 1, in_u8(clamp(u16_1, 0, 63))); @@ -693,9 +695,9 @@ class SimdOpCheckHVX : public SimdOpCheckTest { check("v*:*.h += vtmpy(v*:*.ub, r*.b)", hvx_width, sum(i16(in_u8(x + r3)))); check("v*:*.w += vtmpy(v*:*.h, r*.b)", hvx_width, sum(i32(in_i16(x + r3)))); // TODO: This should work, a common stencil - //check("v*:*.h += vtmpy(v*:*.b, r*.b)", hvx_width, sum(i16(in_i8(x + r3)) * mux(r3, {1, 2, 1}))); - //check("v*:*.h += vtmpy(v*:*.ub, r*.b)", hvx_width, sum(i16(in_u8(x + r3)) * mux(r3, {1, 2, 1}))); - //check("v*:*.w += vtmpy(v*:*.h, r*.b)", hvx_width, sum(i32(in_i16(x + r3)) * mux(r3, {1, 2, 1}))); + // check("v*:*.h += vtmpy(v*:*.b, r*.b)", hvx_width, sum(i16(in_i8(x + r3)) * mux(r3, {1, 2, 1}))); + // check("v*:*.h += vtmpy(v*:*.ub, r*.b)", hvx_width, sum(i16(in_u8(x + r3)) * mux(r3, {1, 2, 1}))); + // check("v*:*.w += vtmpy(v*:*.h, r*.b)", hvx_width, sum(i32(in_i16(x + r3)) * mux(r3, {1, 2, 1}))); } private: @@ -717,7 +719,9 @@ int main(int argc, char **argv) { t.set_feature(f); } } + if (t == Target("hexagon-32-noos")) { + Halide::Internal::Test::Sharder::accept_sharded_status(); printf("[SKIP] No HVX target enabled.\n"); return 0; } @@ -726,7 +730,6 @@ int main(int argc, char **argv) { if (argc > 1) { test_hvx.filter = argv[1]; - test_hvx.set_num_threads(1); } if (getenv("HL_SIMD_OP_CHECK_FILTER")) { @@ -739,20 +742,6 @@ int main(int argc, char **argv) { // Remove some features like simd_op_check.cpp used to do. - // TODO: multithreading here is the cause of https://github.com/halide/Halide/issues/3669; - // the fundamental issue is that we make one set of ImageParams to construct many - // Exprs, then realize those Exprs on arbitrary threads; it is known that sharing - // one Func across multiple threads is not guaranteed to be safe, and indeed, TSAN - // reports data races, of which some are likely 'benign' (e.g. Function.freeze) but others - // are highly suspect (e.g. Function.lock_loop_levels). Since multithreading here - // was added just to avoid having this test be the last to finish, the expedient 'fix' - // for now is to remove the multithreading. A proper fix could be made by restructuring this - // test so that every Expr constructed for testing was guaranteed to share no Funcs - // (Function.deep_copy() perhaps). Of course, it would also be desirable to allow Funcs, Exprs, etc - // to be usable across multiple threads, but that is a major undertaking that is - // definitely not worthwhile for present Halide usage patterns. - test_hvx.set_num_threads(1); - if (argc > 2) { // Don't forget: if you want to run the standard tests to a specific output // directory, you'll need to invoke with the first arg enclosed diff --git a/test/correctness/simplify.cpp b/test/correctness/simplify.cpp index 2771c26c1206..aa10815663d2 100644 --- a/test/correctness/simplify.cpp +++ b/test/correctness/simplify.cpp @@ -2348,6 +2348,10 @@ int main(int argc, char **argv) { Evaluate::make(0)); } + { + check(concat_bits({x}), x); + } + // Check a bounds-related fuzz tester failure found in issue https://github.com/halide/Halide/issues/3764 check(Let::make("b", 105, 336 / max(cast(cast(Variable::make(Int(32), "b"))), 38) + 29), 32); diff --git a/test/correctness/skip_stages.cpp b/test/correctness/skip_stages.cpp index 7f0dea3e962d..64be635b507e 100644 --- a/test/correctness/skip_stages.cpp +++ b/test/correctness/skip_stages.cpp @@ -3,14 +3,8 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - int call_count[4]; -extern "C" DLLEXPORT int call_counter(int x, int idx) { +extern "C" HALIDE_EXPORT_SYMBOL int call_counter(int x, int idx) { call_count[idx]++; return x; } diff --git a/test/correctness/skip_stages_external_array_functions.cpp b/test/correctness/skip_stages_external_array_functions.cpp index f865fd79340b..f96098737f1c 100644 --- a/test/correctness/skip_stages_external_array_functions.cpp +++ b/test/correctness/skip_stages_external_array_functions.cpp @@ -3,15 +3,9 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - int bounds_query_count[4]; int call_count[4]; -extern "C" DLLEXPORT int call_counter(halide_buffer_t *input, int x, int idx, halide_buffer_t *output) { +extern "C" HALIDE_EXPORT_SYMBOL int call_counter(halide_buffer_t *input, int x, int idx, halide_buffer_t *output) { if (input->is_bounds_query()) { bounds_query_count[idx]++; input->dim[0] = output->dim[0]; diff --git a/test/correctness/sliding_backwards.cpp b/test/correctness/sliding_backwards.cpp index 14b07177bcd7..79f2218fc96f 100644 --- a/test/correctness/sliding_backwards.cpp +++ b/test/correctness/sliding_backwards.cpp @@ -3,14 +3,8 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - int call_counter = 0; -extern "C" DLLEXPORT int count(int arg) { +extern "C" HALIDE_EXPORT_SYMBOL int count(int arg) { call_counter++; return arg; } diff --git a/test/correctness/sliding_over_guard_with_if.cpp b/test/correctness/sliding_over_guard_with_if.cpp index 4716354f8691..1f8bc5c62376 100644 --- a/test/correctness/sliding_over_guard_with_if.cpp +++ b/test/correctness/sliding_over_guard_with_if.cpp @@ -6,14 +6,8 @@ using namespace Halide; using namespace Halide::Tools; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - int call_count = 0; -extern "C" DLLEXPORT int call_counter(int x, int y) { +extern "C" HALIDE_EXPORT_SYMBOL int call_counter(int x, int y) { call_count++; return x; } diff --git a/test/correctness/sliding_reduction.cpp b/test/correctness/sliding_reduction.cpp index 40e95252bca7..157a9014f7d0 100644 --- a/test/correctness/sliding_reduction.cpp +++ b/test/correctness/sliding_reduction.cpp @@ -3,14 +3,8 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - int counter = 0; -extern "C" DLLEXPORT int call_count(int x) { +extern "C" HALIDE_EXPORT_SYMBOL int call_count(int x) { counter++; assert(counter > 0); return 99; diff --git a/test/correctness/sliding_window.cpp b/test/correctness/sliding_window.cpp index 04fd70635b55..026dc2a1ddc9 100644 --- a/test/correctness/sliding_window.cpp +++ b/test/correctness/sliding_window.cpp @@ -3,14 +3,8 @@ using namespace Halide; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - int count = 0; -extern "C" DLLEXPORT int call_counter(int x, int y) { +extern "C" HALIDE_EXPORT_SYMBOL int call_counter(int x, int y) { count++; return 0; } diff --git a/test/correctness/storage_folding.cpp b/test/correctness/storage_folding.cpp index 20a937c10315..300d45af5a19 100644 --- a/test/correctness/storage_folding.cpp +++ b/test/correctness/storage_folding.cpp @@ -43,14 +43,8 @@ bool check_expected_mallocs(const std::vector &expected) { return true; } -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - // An extern stage that copies input -> output -extern "C" DLLEXPORT int simple_buffer_copy(halide_buffer_t *in, halide_buffer_t *out) { +extern "C" HALIDE_EXPORT_SYMBOL int simple_buffer_copy(halide_buffer_t *in, halide_buffer_t *out) { if (in->is_bounds_query()) { memcpy(in->dim, out->dim, out->dimensions * sizeof(halide_dimension_t)); } else { @@ -60,7 +54,7 @@ extern "C" DLLEXPORT int simple_buffer_copy(halide_buffer_t *in, halide_buffer_t } // An extern stage accesses the input in a non-monotonic way in the y dimension. -extern "C" DLLEXPORT int zigzag_buffer_copy(halide_buffer_t *in, halide_buffer_t *out) { +extern "C" HALIDE_EXPORT_SYMBOL int zigzag_buffer_copy(halide_buffer_t *in, halide_buffer_t *out) { if (in->is_bounds_query()) { memcpy(in->dim, out->dim, out->dimensions * sizeof(halide_dimension_t)); diff --git a/test/correctness/strided_load.cpp b/test/correctness/strided_load.cpp index 4de5f17168cb..ab3d2a5dd329 100644 --- a/test/correctness/strided_load.cpp +++ b/test/correctness/strided_load.cpp @@ -23,7 +23,7 @@ int main(int argc, char **argv) { g(x) = f(2 * x); g.compute_root().vectorize(x, 16).bound(x, 0, 425); // 24 * 2 = 48 < 49 - //g.compile_to_assembly("/dev/stdout", std::vector(), "g"); + // g.compile_to_assembly("/dev/stdout", std::vector(), "g"); g.realize({425}); diff --git a/test/correctness/tiled_matmul.cpp b/test/correctness/tiled_matmul.cpp index 7fbeedef3ecc..6b5e949dd518 100644 --- a/test/correctness/tiled_matmul.cpp +++ b/test/correctness/tiled_matmul.cpp @@ -62,12 +62,33 @@ struct make_int_t { } }; -template -bool matmul() { - constexpr int row = 16; - constexpr int col = 16; - constexpr int acc = 16; +template +void print_mat(const Buffer &buf, int rows, int cols) { + using cast_T = std::conditional_t, int, T>; + for (int j = 0; j != rows; ++j) { + for (int i = 0; i != cols; ++i) { + std::cout << static_cast(buf(i, j)) << " "; + } + std::cout << std::endl; + } +} + +template +void print_mat_rhs(const Buffer &buf, int rows, int cols) { + using cast_T = std::conditional_t, int, T>; + for (int j = 0; j != (rows / (4 / sizeof(T))); ++j) { + for (int k = 0; k != (4 / sizeof(T)); ++k) { + for (int i = 0; i != cols; ++i) { + std::cout << static_cast(buf(k, i, j)) << " "; + } + + std::cout << std::endl; + } + } +} +template +bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { Buffer A_buf(acc, row); Buffer B_buf(4, col, acc / 4); @@ -78,10 +99,6 @@ bool matmul() { mm(x, y) = cast(0); mm(x, y) += cast(A_buf(r, y)) * cast(B_buf(r % 4, x, r / 4)); - constexpr int tile_x = 8; - constexpr int tile_y = 8; - constexpr int tile_r = 4; - Var rxi("rxi"), ryi("ryi"); RVar rri("rri"), rro("rro"); @@ -118,6 +135,15 @@ bool matmul() { result.realize(out); + // uncomment to check the matrices + // std::cout << "Matrix A\n"; + // print_mat(A_buf, row, acc); + // std::cout << "Matrix B\n"; + // print_mat_rhs(B_buf, acc, col); + + // std::cout << "result\n"; + // print_mat(out, row, col); + for (int j = 0; j < row; ++j) { for (int i = 0; i < col; ++i) { int32_t val = 0; @@ -126,21 +152,18 @@ bool matmul() { } if (val != out(i, j)) { std::cerr << "Invalid result at " << i << ", " << j << "\n" - << out(i, j) << " != " << val << "\n"; + << out(i, j) << " != " << val << "\n" + << "Matrix dims: " << row << "x" << col << "x" << acc << "\nTile dims: " << tile_x << "x" << tile_y << "x" << tile_r << "\n"; return false; } } } + std::cout << "Success!\n"; return true; } -bool matmul_bf16() { - // lhs: 32x16, rhs: 16x32 - const int row = 32; - const int col = 32; - const int acc = 16; - +bool matmul_bf16(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { Var x("x"), y("y"); Buffer A(acc, row); Buffer B(2, col, acc / 2); @@ -151,10 +174,6 @@ bool matmul_bf16() { mm(x, y) = cast(0); mm(x, y) += cast(cast(A(r.x, y))) * cast(B(r.x % 2, x, r.x / 2)); - int tile_x = 8; - int tile_y = 8; - int tile_r = 2; - Var rxi("rxi"), ryi("ryi"); RVar rri("rri"), rro("rro"); @@ -190,25 +209,36 @@ bool matmul_bf16() { Buffer out(col, row); // Uncomment to check the asm - //result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul_bf16.ll", {A, B}, target); - //result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); + // result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul_bf16.ll", {A, B}, target); + // result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); result.realize(out); + // uncomment to check the matrices + // std::cout << "Matrix A\n"; + // print_mat(A, row, acc); + // std::cout << "Matrix B\n"; + // print_mat_rhs(B, acc, col); + + // std::cout << "result\n"; + // print_mat(out, row, col); + for (int j = 0; j < row; ++j) { for (int i = 0; i < col; ++i) { float val = 0.f; for (int k = 0; k < acc; ++k) { val += static_cast(A(k, j)) * static_cast(B(k % 2, i, k / 2)); } - if (!equal_eps(val, out(i, j), 0.01f)) { + if (!equal_eps(val, out(i, j), 0.03f)) { std::cerr << "Invalid result at " << i << ", " << j << "\n" - << out(i, j) << " != " << val << "\n"; + << out(i, j) << " != " << val << "\n" + << "Matrix dims: " << row << "x" << col << "x" << acc << "\nTile dims: " << tile_x << "x" << tile_y << "x" << tile_r << "\n"; return false; } } } + std::cout << "Success!\n"; return true; } @@ -217,6 +247,10 @@ auto matmul_us = &matmul; auto matmul_su = &matmul; auto matmul_uu = &matmul; +bool run_tests(bool (*fn)(int, int, int, int, int, int), int element_width) { + return fn(2, 2, 16, 2, 2, 8 / element_width) && fn(4, 4, 8, 4, 4, 8 / element_width) && fn(32, 32, 32, 8, 8, 8 / element_width) && fn(32, 32, 32, 8, 8, 4 / element_width); +} + int main(int argc, char **argv) { Target t = get_jit_target_from_environment(); if (!t.has_feature(Target::AVX512_SapphireRapids)) { @@ -225,41 +259,29 @@ int main(int argc, char **argv) { } printf("Running AMX matmul (signed/signed)\n"); - if (!matmul_ss()) { + if (!run_tests(matmul_ss, 1)) { return -1; - } else { - printf("Success!\n"); } - // llvm >= 13.0 is required for unsigned and float AMX instructions - if (Halide::Internal::get_llvm_version() >= 130) { - printf("Running AMX matmul (signed/unsigned)\n"); - if (!matmul_su()) { - return -1; - } else { - printf("Success!\n"); - } + printf("Running AMX matmul (signed/unsigned)\n"); + if (!run_tests(matmul_su, 1)) { + return -1; + } - printf("Running AMX matmul (unsigned/signed)\n"); - if (!matmul_us()) { - return -1; - } else { - printf("Success!\n"); - } + printf("Running AMX matmul (unsigned/signed)\n"); + if (!run_tests(matmul_us, 1)) { + return -1; + } - printf("Running AMX matmul (unsigned/unsigned)\n"); - if (!matmul_uu()) { - return -1; - } else { - printf("Success!\n"); - } + printf("Running AMX matmul (unsigned/unsigned)\n"); + if (!run_tests(matmul_uu, 1)) { + return -1; + } - printf("Running AMX matmul (bf16)\n"); - if (!matmul_bf16()) { - return -1; - } else { - printf("Success!\n"); - } + printf("Running AMX matmul (bf16)\n"); + if (!run_tests(matmul_bf16, 2)) { + return -1; } + return 0; } \ No newline at end of file diff --git a/test/correctness/tracing_stack.cpp b/test/correctness/tracing_stack.cpp index d3fe04d548b5..6ea21e48ce57 100644 --- a/test/correctness/tracing_stack.cpp +++ b/test/correctness/tracing_stack.cpp @@ -64,6 +64,13 @@ void signal_handler(int signum) { } // namespace int main(int argc, char **argv) { +#ifdef HALIDE_INTERNAL_USING_ASAN + // ASAN also needs to intercept the SIGSEGV signal handler; + // we could probably make these work together, but it's + // also probably not worth the effort. + printf("[SKIP] tracing_stack does not run under ASAN.\n"); + return 0; +#endif signal(SIGSEGV, signal_handler); signal(SIGBUS, signal_handler); diff --git a/test/correctness/typed_func.cpp b/test/correctness/typed_func.cpp new file mode 100644 index 000000000000..8eeca0829e3e --- /dev/null +++ b/test/correctness/typed_func.cpp @@ -0,0 +1,141 @@ +#include "Halide.h" +#include + +using namespace Halide; +using namespace Halide::Internal; + +int main(int argc, char **argv) { + Var x("x"), y("y"); + { + Func f("f"); + + assert(!f.defined()); + // undefined funcs assert-fail for these calls. + // but return 0 for outputs() and dimensions(). + // assert(f.type() == Int(32)); + // assert(f.outputs() == 0); + // assert(f.dimensions() == 0); + } + + // Verify that func with type-and-dim specifications + // return appropriate types, dims, etc even though the func is "undefined" + { + Func f(Int(32), 2, "f"); + + assert(!f.defined()); + const std::vector expected = {Int(32)}; + assert(f.type() == expected[0]); + assert(f.types() == expected); + assert(f.outputs() == 1); + assert(f.dimensions() == 2); + } + + // Same, but for Tuples. + { + Func f({Int(32), Float(64)}, 3, "f"); + + const std::vector expected = {Int(32), Float(64)}; + assert(!f.defined()); + // assert(f.type() == expected[0]); // will assert-fail + assert(f.types() == expected); + assert(f.outputs() == 2); + assert(f.dimensions() == 3); + } + + // Verify that the Func for an ImageParam gets required-types, etc, set. + { + ImageParam im(Int(32), 2, "im"); + Func f = im; + + // Have to peek directly at 'required_type', etc since the Func + // actually is defined to peek at a buffer of the right types + const std::vector expected = {Int(32)}; + assert(f.function().required_types() == expected); + assert(f.function().required_dimensions() == 2); + } + + // Verify that we can call output_buffer() on an undefined Func, + // but only if it has type-and-dim specifications. + { + Func f(Int(32), 2, "f"); + + const auto o = f.output_buffer(); + f.output_buffer().dim(0).set_bounds(0, 10).dim(1).set_bounds(0, 10); + + // And now we can define the Func *after* setting values in output_buffer() + f(x, y) = x + y; + + auto r = f.realize({10, 10}); // will assert-fail for values other than 10x10 + Buffer b = r[0]; + b.for_each_element([&](int x, int y) { + assert(b(x, y) == x + y); + }); + } + + // Verify that update stages defined via += and friends *don't* require + // the RHS type to match the LHS type (whether or not the pure definition + // is implicitly defined) + { + Func f(Int(32), 2, "f"); + + f(x, y) = cast(1); + f(x, y) += cast(x + y); + + auto r = f.realize({10, 10}); + Buffer b = r[0]; + b.for_each_element([&](int x, int y) { + assert(b(x, y) == 1 + (uint8_t)(x + y)); + }); + } + + { + Func f(Int(32), 2, "f"); + + // f(x, y) = cast(0); // leave out, so Halide injects the implicit init + f(x, y) += cast(x + y); + + auto r = f.realize({10, 10}); + Buffer b = r[0]; + b.for_each_element([&](int x, int y) { + assert(b(x, y) == 0 + (uint8_t)(x + y)); + }); + } + + // Same, but with Tuples + { + Func f({Int(32), Int(8)}, 2, "f"); + + f(x, y) = Tuple(cast(1), cast(2)); + f(x, y) += Tuple(cast(x + y), cast(x - y)); + + auto r = f.realize({10, 10}); + Buffer b0 = r[0]; + Buffer b1 = r[1]; + b0.for_each_element([&](int x, int y) { + assert(b0(x, y) == 1 + (uint8_t)(x + y)); + }); + b1.for_each_element([&](int x, int y) { + assert(b1(x, y) == 2 + (int8_t)(x - y)); + }); + } + + { + Func f({Int(32), Int(8)}, 2, "f"); + + // f(x, y) = Tuple(cast(1), cast(2)); // leave out, so Halide injects the implicit init + f(x, y) += Tuple(cast(x + y), cast(x - y)); + + auto r = f.realize({10, 10}); + Buffer b0 = r[0]; + Buffer b1 = r[1]; + b0.for_each_element([&](int x, int y) { + assert(b0(x, y) == 0 + (uint8_t)(x + y)); + }); + b1.for_each_element([&](int x, int y) { + assert(b1(x, y) == 0 + (int8_t)(x - y)); + }); + } + + printf("Success!\n"); + return 0; +} diff --git a/test/correctness/unroll_huge_mux.cpp b/test/correctness/unroll_huge_mux.cpp index 9a6307d68414..233ee038c4e8 100644 --- a/test/correctness/unroll_huge_mux.cpp +++ b/test/correctness/unroll_huge_mux.cpp @@ -3,6 +3,11 @@ using namespace Halide; int main(int argc, char **argv) { +#ifdef HALIDE_INTERNAL_USING_ASAN + printf("[SKIP] unroll_huge_mux requires set_compiler_stack_size() to work properly, which is disabled under ASAN.\n"); + return 0; +#endif + Func f; Var x; diff --git a/test/correctness/vector_cast.cpp b/test/correctness/vector_cast.cpp index b15819ac9a50..c5afbfc330d0 100644 --- a/test/correctness/vector_cast.cpp +++ b/test/correctness/vector_cast.cpp @@ -1,5 +1,6 @@ #include "Halide.h" -#include +#include "test_sharding.h" + #include using namespace Halide; @@ -34,6 +35,11 @@ bool is_type_supported(int vec_width, const Target &target) { template bool test(int vec_width, const Target &target) { + // Useful for debugging; leave in (commented out) + // printf("Test %s x %d -> %s x %d\n", + // string_of_type(), vec_width, + // string_of_type(), vec_width); + if (!is_type_supported(vec_width, target) || !is_type_supported(vec_width, target)) { // Type not supported, return pass. return true; @@ -62,7 +68,7 @@ bool test(int vec_width, const Target &target) { } else { if (target.has_feature(Target::HVX)) { // TODO: Non-native vector widths hang the compiler here. - //f.hexagon(); + // f.hexagon(); } if (vec_width > 1) { f.vectorize(x, vec_width); @@ -101,18 +107,21 @@ bool test(int vec_width, const Target &target) { return true; } +struct Task { + std::function fn; +}; + template -bool test_all(int vec_width, const Target &target) { - bool success = true; - success = success && test(vec_width, target); - success = success && test(vec_width, target); - success = success && test(vec_width, target); - success = success && test(vec_width, target); - success = success && test(vec_width, target); - success = success && test(vec_width, target); - success = success && test(vec_width, target); - success = success && test(vec_width, target); - return success; +void add_all(int vec_width, const Target &target, std::vector &tasks) { + tasks.push_back({[=]() { return test(vec_width, target); }}); + tasks.push_back({[=]() { return test(vec_width, target); }}); + tasks.push_back({[=]() { return test(vec_width, target); }}); + tasks.push_back({[=]() { return test(vec_width, target); }}); + tasks.push_back({[=]() { return test(vec_width, target); }}); + tasks.push_back({[=]() { return test(vec_width, target); }}); + tasks.push_back({[=]() { return test(vec_width, target); }}); + tasks.push_back({[=]() { return test(vec_width, target); }}); + tasks.push_back({[=]() { return test(vec_width, target); }}); } int main(int argc, char **argv) { @@ -129,34 +138,33 @@ int main(int argc, char **argv) { Target target = get_jit_target_from_environment(); // We only test power-of-two vector widths for now - Halide::Internal::ThreadPool pool; - std::vector> futures; int vec_width_max = 64; if (target.arch == Target::WebAssembly) { // The wasm jit is very slow, so shorten this test here. vec_width_max = 16; } + std::vector tasks; for (int vec_width = 1; vec_width <= vec_width_max; vec_width *= 2) { - futures.push_back(pool.async([=]() { - bool success = true; - success = success && test_all(vec_width, target); - success = success && test_all(vec_width, target); - success = success && test_all(vec_width, target); - success = success && test_all(vec_width, target); - success = success && test_all(vec_width, target); - success = success && test_all(vec_width, target); - success = success && test_all(vec_width, target); - success = success && test_all(vec_width, target); - return success; - })); + add_all(vec_width, target, tasks); + add_all(vec_width, target, tasks); + add_all(vec_width, target, tasks); + add_all(vec_width, target, tasks); + add_all(vec_width, target, tasks); + add_all(vec_width, target, tasks); + add_all(vec_width, target, tasks); + add_all(vec_width, target, tasks); } - bool ok = true; - for (auto &f : futures) { - ok &= f.get(); + using Sharder = Halide::Internal::Test::Sharder; + Sharder sharder; + for (size_t t = 0; t < tasks.size(); t++) { + if (!sharder.should_run(t)) continue; + const auto &task = tasks.at(t); + if (!task.fn()) { + exit(-1); + } } - if (!ok) return -1; printf("Success!\n"); return 0; } diff --git a/test/correctness/vector_math.cpp b/test/correctness/vector_math.cpp index b8d846c3b1e4..d0c6ed530c9d 100644 --- a/test/correctness/vector_math.cpp +++ b/test/correctness/vector_math.cpp @@ -1,7 +1,8 @@ #include "Halide.h" +#include "test_sharding.h" + #include #include -#include #include #include #include @@ -713,32 +714,42 @@ bool test(int lanes, int seed) { } int main(int argc, char **argv) { - int seed = argc > 1 ? atoi(argv[1]) : time(nullptr); std::cout << "vector_math test seed: " << seed << std::endl; + struct Task { + std::function fn; + int lanes; + int seed; + }; + // Only native vector widths - llvm doesn't handle others well - Halide::Internal::ThreadPool pool; - std::vector> futures; - futures.push_back(pool.async(test, 4, seed)); - futures.push_back(pool.async(test, 8, seed)); - futures.push_back(pool.async(test, 2, seed)); - futures.push_back(pool.async(test, 16, seed)); - futures.push_back(pool.async(test, 16, seed)); - futures.push_back(pool.async(test, 8, seed)); - futures.push_back(pool.async(test, 8, seed)); - futures.push_back(pool.async(test, 4, seed)); - futures.push_back(pool.async(test, 4, seed)); - futures.push_back(pool.async(test, 8, seed)); - futures.push_back(pool.async(test, 16, seed)); - futures.push_back(pool.async(test, 8, seed)); - futures.push_back(pool.async(test, 16, seed)); - bool ok = true; - for (auto &f : futures) { - ok &= f.get(); + std::vector tasks = { + {test, 4, seed}, + {test, 8, seed}, + {test, 2, seed}, + {test, 16, seed}, + {test, 16, seed}, + {test, 8, seed}, + {test, 8, seed}, + {test, 4, seed}, + {test, 4, seed}, + {test, 8, seed}, + {test, 16, seed}, + {test, 8, seed}, + {test, 16, seed}, + }; + + using Sharder = Halide::Internal::Test::Sharder; + Sharder sharder; + for (size_t t = 0; t < tasks.size(); t++) { + if (!sharder.should_run(t)) continue; + const auto &task = tasks.at(t); + if (!task.fn(task.lanes, task.seed)) { + exit(-1); + } } - if (!ok) return -1; printf("Success!\n"); return 0; } diff --git a/test/correctness/vector_reductions.cpp b/test/correctness/vector_reductions.cpp index d4d2acc43984..1f3bf90e73c5 100644 --- a/test/correctness/vector_reductions.cpp +++ b/test/correctness/vector_reductions.cpp @@ -1,8 +1,16 @@ #include "Halide.h" +#include "test_sharding.h" using namespace Halide; -int main(int argc, char **argv) { +namespace { + +struct Task { + Target target; + std::function fn; +}; + +void add_tasks(const Target &target, std::vector &tasks) { for (int dst_lanes : {1, 3}) { for (int reduce_factor : {2, 3, 4}) { std::vector types = @@ -103,26 +111,96 @@ int main(int argc, char **argv) { .vectorize(rx); ref.compute_root(); - RDom c(0, 128); - Expr err = cast(maximum(absd(f(c), ref(c)))); + const auto fn = [=]() { + // Useful for debugging; leave in (commented out) + // std::cout << "Testing: " + // << " target: " << target + // << " dst_lanes: " << dst_lanes + // << " reduce_factor " << reduce_factor + // << " src_type " << src_type + // << " widen_factor " << widen_factor + // << " dst_type " << dst_type + // << " op " << op + // << "\n"; - double e = evaluate(err); + RDom c(0, 128); - if (e > 1e-3) { - std::cerr - << "Horizontal reduction produced different output when vectorized!\n" - << "Maximum error = " << e << "\n" - << "Reducing from " << src_type.with_lanes(src_lanes) - << " to " << dst_type.with_lanes(dst_lanes) << "\n" - << "RHS: " << f.update_value() << "\n"; - exit(-1); - } + // Func.evaluate() doesn't let you specify a Target (!), + // so let's use Func.realize() instead. + Func err("err"); + err() = cast(maximum(absd(f(c), ref(c)))); + Buffer err_im = err.realize({}, target); + double e = err_im(); + + if (e > 1e-3) { + std::cerr + << "Horizontal reduction produced different output when vectorized!\n" + << "Maximum error = " << e << "\n" + << "Reducing from " << src_type.with_lanes(src_lanes) + << " to " << dst_type.with_lanes(dst_lanes) << "\n" + << "RHS: " << f.update_value() << "\n"; + exit(-1); + } + }; + tasks.push_back({target, fn}); } } } } } +} + +} // namespace + +int main(int argc, char **argv) { + Target target = get_jit_target_from_environment(); + + std::vector tasks; + add_tasks(target, tasks); + + if (target.arch == Target::X86) { + // LLVM has had SIMD codegen errors that we missed because we didn't test against + // multiple SIMD architectures, using just 'host' instead. To remedy this, we'll + // re-run this multiple times, downgrading the SIMD successively, to ensure we get + // test coverage. Note that this doesn't attempt to be exhaustive -- there are too + // many permutations to really test, especially with AVX512 -- but this way we + // can get at least baseline coverage for the major variants. + // + // (Note also that our codegen for x86 implicitly 'fills in' required prerequisites, + // e.g. if you specify a target with AVX2, the codegen will automatically include + // AVX and SSE41 as well.) + if (target.has_feature(Target::AVX512)) { + Target avx2_target(target.os, target.arch, target.bits, {Target::AVX2}); + add_tasks(avx2_target, tasks); + } + if (target.has_feature(Target::AVX2)) { + Target sse41_target(target.os, target.arch, target.bits, {Target::AVX}); + add_tasks(sse41_target, tasks); + } + if (target.has_feature(Target::AVX)) { + Target sse41_target(target.os, target.arch, target.bits, {Target::SSE41}); + add_tasks(sse41_target, tasks); + } + if (target.has_feature(Target::SSE41)) { + // Halide assumes that all x86 targets have at least sse2 + Target sse2_target(target.os, target.arch, target.bits); + add_tasks(sse2_target, tasks); + } + } + + using Sharder = Halide::Internal::Test::Sharder; + Sharder sharder; + Target prev_target; + for (size_t t = 0; t < tasks.size(); t++) { + if (!sharder.should_run(t)) continue; + const auto &task = tasks.at(t); + if (task.target != prev_target) { + std::cout << "vector_reductions: Testing with " << task.target << "\n"; + prev_target = task.target; + } + task.fn(); + } - printf("Success!\n"); + std::cout << "Success!\n"; return 0; } diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index ee75855332e3..ee9d47923469 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -9,8 +9,8 @@ tests(GROUPS error auto_schedule_no_parallel.cpp auto_schedule_no_reorder.cpp autodiff_unbounded.cpp - bad_bound_storage.cpp bad_bound.cpp + bad_bound_storage.cpp bad_compute_at.cpp bad_compute_with.cpp bad_compute_with_invalid_specialization.cpp @@ -29,9 +29,15 @@ tests(GROUPS error bad_store_at.cpp broken_promise.cpp buffer_larger_than_two_gigs.cpp + callable_bad_arguments.cpp + callable_bad_values_passed.cpp + callable_typed_bad_arguments.cpp + callable_typed_bad_arguments_buffer_dims.cpp + callable_typed_bad_arguments_buffer_type.cpp clamp_out_of_range.cpp compute_with_crossing_edges1.cpp compute_with_crossing_edges2.cpp + compute_with_fuse_in_specialization.cpp constrain_wrong_output_buffer.cpp constraint_uses_non_param.cpp define_after_realize.cpp @@ -44,18 +50,26 @@ tests(GROUPS error five_d_gpu_buffer.cpp float_arg.cpp forward_on_undefined_buffer.cpp + func_expr_dim_mismatch.cpp + func_expr_type_mismatch.cpp + func_expr_update_type_mismatch.cpp + func_extern_dim_mismatch.cpp + func_extern_type_mismatch.cpp + func_tuple_dim_mismatch.cpp + func_tuple_types_mismatch.cpp + func_tuple_update_types_mismatch.cpp implicit_args.cpp impossible_constraints.cpp - init_def_should_be_all_vars.cpp incomplete_target.cpp + init_def_should_be_all_vars.cpp inspect_loop_level.cpp lerp_float_weight_out_of_range.cpp lerp_mismatch.cpp lerp_signed_weight.cpp memoize_different_compute_store.cpp memoize_redefine_eviction_key.cpp - metal_vector_too_large.cpp metal_threads_too_large.cpp + metal_vector_too_large.cpp missing_args.cpp no_default_device.cpp nonexistent_update_stage.cpp @@ -64,6 +78,7 @@ tests(GROUPS error pointer_arithmetic.cpp race_condition.cpp rdom_undefined.cpp + rdom_where_races.cpp realization_with_too_many_outputs.cpp realize_constantly_larger_than_two_gigs.cpp reduction_bounds.cpp @@ -89,11 +104,11 @@ tests(GROUPS error undefined_pipeline_realize.cpp undefined_rdom_dimension.cpp unknown_target.cpp + vector_tile.cpp vectorize_dynamic.cpp vectorize_too_little.cpp vectorize_too_much.cpp vectorized_extern.cpp - vector_tile.cpp wrap_custom_after_shared.cpp wrap_frozen.cpp wrapper_never_used.cpp diff --git a/test/error/auto_schedule_no_parallel.cpp b/test/error/auto_schedule_no_parallel.cpp index 74e2b269025f..2519619a3b1b 100644 --- a/test/error/auto_schedule_no_parallel.cpp +++ b/test/error/auto_schedule_no_parallel.cpp @@ -25,7 +25,11 @@ int main(int argc, char **argv) { // This should throw an error since auto-scheduler does not currently // support partial schedules +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif printf("Success!\n"); return 0; diff --git a/test/error/auto_schedule_no_reorder.cpp b/test/error/auto_schedule_no_reorder.cpp index 8f39114ee9ea..d9fb344473e4 100644 --- a/test/error/auto_schedule_no_reorder.cpp +++ b/test/error/auto_schedule_no_reorder.cpp @@ -25,7 +25,11 @@ int main(int argc, char **argv) { // This should throw an error since auto-scheduler does not currently // support partial schedules +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API p.auto_schedule(target); +#else + p.apply_autoscheduler(target, {"Mullapudi2016"}); +#endif printf("Success!\n"); return 0; diff --git a/test/error/callable_bad_arguments.cpp b/test/error/callable_bad_arguments.cpp new file mode 100644 index 000000000000..edcec6f20523 --- /dev/null +++ b/test/error/callable_bad_arguments.cpp @@ -0,0 +1,28 @@ +#include "Halide.h" +#include +#include + +using namespace Halide; + +void check(int r) { + assert(r == 0); +} + +int main(int argc, char **argv) { + Param p_int(42); + Param p_float(1.0f); + ImageParam p_img(UInt(8), 2); + + Var x("x"), y("y"); + Func f("f"); + + f(x, y) = p_img(x, y) + cast(p_int / p_float); + + // Should fail with "Generated code refers to parameter p_int, which was not found in the argument list." + Callable c = f.compile_to_callable({p_img, p_float}); + + // Shouldn't get here, but if we do, return success, which is a failure... + + printf("Success!\n"); + return 0; +} diff --git a/test/error/callable_bad_values_passed.cpp b/test/error/callable_bad_values_passed.cpp new file mode 100644 index 000000000000..e786bb678f13 --- /dev/null +++ b/test/error/callable_bad_values_passed.cpp @@ -0,0 +1,31 @@ +#include "Halide.h" +#include +#include + +using namespace Halide; + +void check(int r) { + assert(r == 0); +} + +int main(int argc, char **argv) { + Param p_int(42); + Param p_float(1.0f); + ImageParam p_img(UInt(8), 2); + + Var x("x"), y("y"); + Func f("f"); + + f(x, y) = p_img(x, y) + cast(p_int / p_float); + + Buffer in1(10, 10), result1(10, 10); + in1.fill(0); + + Callable c = f.compile_to_callable({p_img, p_int, p_float}); + + // Should fail with something like "Argument 2 of 4 ('p_int') was expected to be a scalar of type 'int32'." + int r = c(in1, 3.1415927, 1.0f, result1); + _halide_user_assert(r == 0); + + printf("Success!\n"); +} diff --git a/test/error/callable_typed_bad_arguments.cpp b/test/error/callable_typed_bad_arguments.cpp new file mode 100644 index 000000000000..29a0405e0ab9 --- /dev/null +++ b/test/error/callable_typed_bad_arguments.cpp @@ -0,0 +1,31 @@ +#include "Halide.h" +#include +#include + +using namespace Halide; + +void check(int r) { + assert(r == 0); +} + +int main(int argc, char **argv) { + Param p_int(42); + Param p_float(1.0f); + ImageParam p_img(UInt(8), 2); + + Var x("x"), y("y"); + Func f("f"); + + f(x, y) = p_img(x, y) + cast(p_int / p_float); + + Buffer in1(10, 10), result1(10, 10); + in1.fill(0); + + // Should fail with "Error defining 'f': Argument 1 of 4 ('p_int') was expected to be a scalar of type 'int32'." + auto c = f.compile_to_callable({p_int, p_float, p_img}) + .make_std_function, uint8_t, float, Buffer>(); + + // Shouldn't get here, but if we do, return success, which is a failure... + + printf("Success!\n"); +} diff --git a/test/error/callable_typed_bad_arguments_buffer_dims.cpp b/test/error/callable_typed_bad_arguments_buffer_dims.cpp new file mode 100644 index 000000000000..d27c64b55e8a --- /dev/null +++ b/test/error/callable_typed_bad_arguments_buffer_dims.cpp @@ -0,0 +1,31 @@ +#include "Halide.h" +#include +#include + +using namespace Halide; + +void check(int r) { + assert(r == 0); +} + +int main(int argc, char **argv) { + Param p_int(42); + Param p_float(1.0f); + ImageParam p_img(UInt(8), 2); + + Var x("x"), y("y"); + Func f("f"); + + f(x, y) = p_img(x, y) + cast(p_int / p_float); + + Buffer in1(10, 10), result1(10, 10); + in1.fill(0); + + // Should fail with "Error defining 'f': Argument 1 of 4 ('p_img') was expected to be a buffer of type 'uint8' and dimension 2." + auto c = f.compile_to_callable({p_img, p_int, p_float}) + .make_std_function, int32_t, float, Buffer>(); + + // Shouldn't get here, but if we do, return success, which is a failure... + + printf("Success!\n"); +} diff --git a/test/error/callable_typed_bad_arguments_buffer_type.cpp b/test/error/callable_typed_bad_arguments_buffer_type.cpp new file mode 100644 index 000000000000..e7c1d081d286 --- /dev/null +++ b/test/error/callable_typed_bad_arguments_buffer_type.cpp @@ -0,0 +1,31 @@ +#include "Halide.h" +#include +#include + +using namespace Halide; + +void check(int r) { + assert(r == 0); +} + +int main(int argc, char **argv) { + Param p_int(42); + Param p_float(1.0f); + ImageParam p_img(UInt(8), 2); + + Var x("x"), y("y"); + Func f("f"); + + f(x, y) = p_img(x, y) + cast(p_int / p_float); + + Buffer in1(10, 10), result1(10, 10); + in1.fill(0); + + // Should fail with "Error defining 'f': Argument 1 of 4 ('p_img') was expected to be a buffer of type 'uint8' and dimension 2." + auto c = f.compile_to_callable({p_img, p_int, p_float}) + .make_std_function, int32_t, float, Buffer>(); + + // Shouldn't get here, but if we do, return success, which is a failure... + + printf("Success!\n"); +} diff --git a/test/error/compute_with_fuse_in_specialization.cpp b/test/error/compute_with_fuse_in_specialization.cpp new file mode 100644 index 000000000000..6561956b1166 --- /dev/null +++ b/test/error/compute_with_fuse_in_specialization.cpp @@ -0,0 +1,22 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Var x("x"), y("y"), f("f"); + ImageParam in(Int(16), 2, "in"); + Func out0("out0"), out1("out1"); + out0(x, y) = 1 * in(x, y); + out1(x, y) = 2 * in(x, y); + + out0.vectorize(x, 8, TailStrategy::RoundUp); + out1.vectorize(x, 8, TailStrategy::RoundUp).compute_with(out0, x); + + out0.specialize(in.dim(1).stride() == 128).fuse(x, y, f); + Pipeline p({out0, out1}); + p.compile_jit(); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/func_expr_dim_mismatch.cpp b/test/error/func_expr_dim_mismatch.cpp new file mode 100644 index 000000000000..1218f70ecb7b --- /dev/null +++ b/test/error/func_expr_dim_mismatch.cpp @@ -0,0 +1,17 @@ +#include "Halide.h" +#include + +using namespace Halide; +using namespace Halide::Internal; + +int main(int argc, char **argv) { + Var x("x"), y("y"); + Func f(Int(32), 1, "f"); + + f(x, y) = cast(0); + + f.realize({100, 100}); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/func_expr_type_mismatch.cpp b/test/error/func_expr_type_mismatch.cpp new file mode 100644 index 000000000000..1337891cacf3 --- /dev/null +++ b/test/error/func_expr_type_mismatch.cpp @@ -0,0 +1,17 @@ +#include "Halide.h" +#include + +using namespace Halide; +using namespace Halide::Internal; + +int main(int argc, char **argv) { + Var x("x"), y("y"); + Func f(Float(32), 1, "f"); + + f(x, y) = cast(0); + + f.realize({100, 100}); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/func_expr_update_type_mismatch.cpp b/test/error/func_expr_update_type_mismatch.cpp new file mode 100644 index 000000000000..c2fef5d8ba6c --- /dev/null +++ b/test/error/func_expr_update_type_mismatch.cpp @@ -0,0 +1,18 @@ +#include "Halide.h" +#include + +using namespace Halide; +using namespace Halide::Internal; + +int main(int argc, char **argv) { + Var x("x"), y("y"); + Func f(Float(32), 2, "f"); + + f(x, y) = 0.f; + f(x, y) = cast(0); + + f.realize({100, 100}); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/func_extern_dim_mismatch.cpp b/test/error/func_extern_dim_mismatch.cpp new file mode 100644 index 000000000000..41c39ca47f9c --- /dev/null +++ b/test/error/func_extern_dim_mismatch.cpp @@ -0,0 +1,14 @@ +#include "Halide.h" +#include + +using namespace Halide; +using namespace Halide::Internal; + +int main(int argc, char **argv) { + Var x("x"), y("y"); + Func f(Float(32), 1, "f"); + f.define_extern("test", {}, Float(32), {x, y}); + f.realize({100, 100}); + printf("Success!\n"); + return 0; +} diff --git a/test/error/func_extern_type_mismatch.cpp b/test/error/func_extern_type_mismatch.cpp new file mode 100644 index 000000000000..ad137f40aca3 --- /dev/null +++ b/test/error/func_extern_type_mismatch.cpp @@ -0,0 +1,14 @@ +#include "Halide.h" +#include + +using namespace Halide; +using namespace Halide::Internal; + +int main(int argc, char **argv) { + Var x("x"), y("y"); + Func f({UInt(8), Float(64)}, 2, "f"); + f.define_extern("test", {}, {Int(32), Float(32)}, {x, y}); + f.realize({100, 100}); + printf("Success!\n"); + return 0; +} diff --git a/test/error/func_tuple_dim_mismatch.cpp b/test/error/func_tuple_dim_mismatch.cpp new file mode 100644 index 000000000000..79f97217d3b4 --- /dev/null +++ b/test/error/func_tuple_dim_mismatch.cpp @@ -0,0 +1,17 @@ +#include "Halide.h" +#include + +using namespace Halide; +using namespace Halide::Internal; + +int main(int argc, char **argv) { + Var x("x"), y("y"); + Func f({Int(32), Float(32)}, 1, "f"); + + f(x, y) = {cast(0), cast(0)}; + + f.realize({100, 100}); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/func_tuple_types_mismatch.cpp b/test/error/func_tuple_types_mismatch.cpp new file mode 100644 index 000000000000..a04eaf45ce71 --- /dev/null +++ b/test/error/func_tuple_types_mismatch.cpp @@ -0,0 +1,17 @@ +#include "Halide.h" +#include + +using namespace Halide; +using namespace Halide::Internal; + +int main(int argc, char **argv) { + Var x("x"), y("y"); + Func f({UInt(8), Float(64)}, 2, "f"); + + f(x, y) = {cast(0), cast(0)}; + + f.realize({100, 100}); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/func_tuple_update_types_mismatch.cpp b/test/error/func_tuple_update_types_mismatch.cpp new file mode 100644 index 000000000000..4f8b7894763c --- /dev/null +++ b/test/error/func_tuple_update_types_mismatch.cpp @@ -0,0 +1,18 @@ +#include "Halide.h" +#include + +using namespace Halide; +using namespace Halide::Internal; + +int main(int argc, char **argv) { + Var x("x"), y("y"); + Func f({UInt(8), Float(64)}, 2, "f"); + + f(x, y) = {cast(0), cast(0)}; + f(x, y) = {cast(0), cast(0)}; + + f.realize({100, 100}); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/missing_args.cpp b/test/error/missing_args.cpp index fbddb04b7e45..5669cbf8217a 100644 --- a/test/error/missing_args.cpp +++ b/test/error/missing_args.cpp @@ -12,8 +12,8 @@ int main(int argc, char **argv) { f(x) = im(x, x) + arg; std::vector args; - //args.push_back(im); - //args.push_back(arg); + // args.push_back(im); + // args.push_back(arg); f.compile_to_object("f.o", args, "f"); printf("Success!\n"); diff --git a/test/error/rdom_where_races.cpp b/test/error/rdom_where_races.cpp new file mode 100644 index 000000000000..bcc00e78681c --- /dev/null +++ b/test/error/rdom_where_races.cpp @@ -0,0 +1,20 @@ +// https://github.com/halide/Halide/issues/6808 +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Func f; + Var x; + + RDom r(0, 10); + f(x) = 1; + r.where(f(0) == 1); + f(r) = 2; + + f.update().parallel(r); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/realization_with_too_many_outputs.cpp b/test/error/realization_with_too_many_outputs.cpp index 7a3db8c1b3c3..676279f92c39 100644 --- a/test/error/realization_with_too_many_outputs.cpp +++ b/test/error/realization_with_too_many_outputs.cpp @@ -13,7 +13,7 @@ int main(int argc, char **argv) { Buffer first(10); Buffer second(10); - Realization r(first, second); + Realization r({first, second}); f.realize(r); printf("Success!\n"); diff --git a/test/failing_with_issue/3292_async_specialize.cpp b/test/failing_with_issue/3292_async_specialize.cpp index 213812d953ea..19abb83635a0 100644 --- a/test/failing_with_issue/3292_async_specialize.cpp +++ b/test/failing_with_issue/3292_async_specialize.cpp @@ -19,14 +19,8 @@ void my_free(void *user_context, void *ptr) { free(((void **)ptr)[-1]); } -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - // An extern stage that copies input -> output -extern "C" DLLEXPORT int simple_buffer_copy(halide_buffer_t *in, halide_buffer_t *out) { +extern "C" HALIDE_EXPORT_SYMBOL int simple_buffer_copy(halide_buffer_t *in, halide_buffer_t *out) { if (in->is_bounds_query()) { memcpy(in->dim, out->dim, out->dimensions * sizeof(halide_dimension_t)); } else { diff --git a/test/failing_with_issue/3293_storage_folding_async.cpp b/test/failing_with_issue/3293_storage_folding_async.cpp index ebfc51c750c1..6f7dfb842e34 100644 --- a/test/failing_with_issue/3293_storage_folding_async.cpp +++ b/test/failing_with_issue/3293_storage_folding_async.cpp @@ -19,14 +19,8 @@ void my_free(void *user_context, void *ptr) { free(((void **)ptr)[-1]); } -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - // An extern stage that copies input -> output -extern "C" DLLEXPORT int simple_buffer_copy(halide_buffer_t *in, halide_buffer_t *out) { +extern "C" HALIDE_EXPORT_SYMBOL int simple_buffer_copy(halide_buffer_t *in, halide_buffer_t *out) { if (in->is_bounds_query()) { memcpy(in->dim, out->dim, out->dimensions * sizeof(halide_dimension_t)); } else { diff --git a/test/failing_with_issue/CMakeLists.txt b/test/failing_with_issue/CMakeLists.txt index 9ce6e26082ce..fa015e4d9e94 100644 --- a/test/failing_with_issue/CMakeLists.txt +++ b/test/failing_with_issue/CMakeLists.txt @@ -4,4 +4,4 @@ tests(GROUPS failing_with_issue 3293_storage_folding_async.cpp 3357_vectorize_pred.cpp 4283_store_at_gpu.cpp - ) \ No newline at end of file + ) diff --git a/test/generator/CMakeLists.txt b/test/generator/CMakeLists.txt index 48cd8c333230..f3e90f988b93 100644 --- a/test/generator/CMakeLists.txt +++ b/test/generator/CMakeLists.txt @@ -5,7 +5,7 @@ function(halide_define_aot_test NAME) set(options OMIT_DEFAULT_GENERATOR) set(oneValueArgs FUNCTION_NAME) - set(multiValueArgs GEN_DEPS EXTRA_LIBS ENABLE_IF FEATURES PARAMS GEN_TARGET) + set(multiValueArgs GEN_DEPS EXTRA_LIBS ENABLE_IF FEATURES PARAMS GEN_TARGET GROUPS) cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) if (args_ENABLE_IF AND NOT (${args_ENABLE_IF})) @@ -37,12 +37,14 @@ function(halide_define_aot_test NAME) "${Halide_SOURCE_DIR}/tools" "${CMAKE_CURRENT_BINARY_DIR}") - add_wasm_halide_test("${TARGET}" GROUPS generator) + add_wasm_halide_test("${TARGET}" GROUPS generator "${args_GROUPS}") else () add_executable("${TARGET}" "${NAME}_aottest.cpp") - target_include_directories("${TARGET}" PRIVATE + target_include_directories( + "${TARGET}" PRIVATE "${Halide_SOURCE_DIR}/test/common" - "${Halide_SOURCE_DIR}/tools") + "${Halide_SOURCE_DIR}/tools" + ) if (NOT args_OMIT_DEFAULT_GENERATOR) target_link_libraries(${TARGET} PRIVATE ${NAME}) endif () @@ -60,45 +62,10 @@ function(halide_define_aot_test NAME) if ("${Halide_TARGET}" MATCHES "cuda") target_compile_definitions("${TARGET}" PRIVATE TEST_CUDA) endif () - add_halide_test("${TARGET}" GROUPS generator) + add_halide_test("${TARGET}" GROUPS generator "${args_GROUPS}") endif () endfunction() -## -# Define a nontrivial dependency for external_code.generator -## - -# external_code_extern.cpp -set(EXTERNAL_CPP "${CMAKE_CURRENT_SOURCE_DIR}/external_code_extern.cpp") - -set(EC32 "external_code_extern_bitcode_32") -add_custom_command(OUTPUT "${EC32}.bc" - COMMAND clang -O3 -c -m32 -target le32-unknown-nacl-unknown -emit-llvm "$" -o "${EC32}.bc" - DEPENDS "${EXTERNAL_CPP}" - VERBATIM) -add_custom_command(OUTPUT "${EC32}.cpp" - COMMAND binary2cpp external_code_extern_bitcode_32 < "${EC32}.bc" > "${EC32}.cpp" - DEPENDS "${EC32}.bc" binary2cpp - VERBATIM) - -set(EC64 "external_code_extern_bitcode_64") -add_custom_command(OUTPUT "${EC64}.bc" - COMMAND clang -O3 -c -m64 -target le64-unknown-unknown-unknown -emit-llvm "$" -o "${EC64}.bc" - DEPENDS "${EXTERNAL_CPP}" binary2cpp - VERBATIM) -add_custom_command(OUTPUT "${EC64}.cpp" - COMMAND binary2cpp external_code_extern_bitcode_64 < "${EC64}.bc" > "${EC64}.cpp" - DEPENDS "${EC64}.bc" binary2cpp - VERBATIM) - -set(ECCPP "external_code_extern_cpp_source") -add_custom_command(OUTPUT "${ECCPP}.cpp" - COMMAND binary2cpp external_code_extern_cpp_source < "$" > "${ECCPP}.cpp" - DEPENDS "${EXTERNAL_CPP}" binary2cpp - VERBATIM) - -add_library(external_code_generator_deps OBJECT "${EC32}.cpp" "${EC64}.cpp" "${ECCPP}.cpp") - ## # Some tests are not available when compiling for WASM. ## @@ -115,7 +82,7 @@ if (NOT ${USING_WASM}) endif () if (TARGET_NVPTX AND Halide_TARGET MATCHES "cuda") - include(AddCudaToTarget) + find_package(CUDAToolkit REQUIRED) endif () if (TARGET_NVPTX AND Halide_TARGET MATCHES "opencl") find_package(OpenCL REQUIRED) @@ -130,7 +97,7 @@ endif () # acquire_release_generator.cpp halide_define_aot_test(acquire_release) if (TARGET_NVPTX AND Halide_TARGET MATCHES "cuda") - add_cuda_to_target(generator_aot_acquire_release PRIVATE) + target_link_libraries(generator_aot_acquire_release PRIVATE CUDA::cuda_driver CUDA::cudart) endif () if (TARGET_NVPTX AND Halide_TARGET MATCHES "opencl") target_link_libraries(generator_aot_acquire_release PRIVATE OpenCL::OpenCL) @@ -144,10 +111,17 @@ endif () # alias_aottest.cpp # alias_generator.cpp -halide_define_aot_test(alias EXTRA_LIBS alias_with_offset_42) -add_halide_library(alias_with_offset_42 - FROM alias.generator - GENERATOR alias_with_offset_42) +set(ALIAS_LIBS alias_with_offset_42 alias_Adams2019 alias_Li2018 alias_Mullapudi2016) +halide_define_aot_test(alias EXTRA_LIBS ${ALIAS_LIBS}) +foreach (LIB IN LISTS ALIAS_LIBS) + # We don't really need all the plugins at once here -- + # It's just easier to specify them all (and adds a test that loading + # multiple plugins works) + add_halide_library(${LIB} + FROM alias.generator + GENERATOR ${LIB} + PLUGINS Halide::Adams2019 Halide::Li2018 Halide::Mullapudi2016) +endforeach () # argvcall_aottest.cpp # argvcall_generator.cpp @@ -157,11 +131,13 @@ halide_define_aot_test(argvcall) # async_parallel_generator.cpp halide_define_aot_test(async_parallel # Requires threading support, not yet available for wasm tests - ENABLE_IF NOT ${USING_WASM}) + ENABLE_IF NOT ${USING_WASM} + GROUPS multithreaded) # autograd_aottest.cpp # autograd_generator.cpp -halide_define_aot_test(autograd ENABLE_IF TARGET Halide::Mullapudi2016 AND NOT ${USING_WASM}) +halide_define_aot_test(autograd ENABLE_IF TARGET Halide::Mullapudi2016 AND NOT ${USING_WASM} + GROUPS multithreaded) if (TARGET generator_aot_autograd) add_halide_library(autograd_grad GRADIENT_DESCENT @@ -171,6 +147,10 @@ if (TARGET generator_aot_autograd) target_link_libraries(generator_aot_autograd PRIVATE autograd_grad) endif () +# abstractgeneratortest_aottest.cpp +# abstractgeneratortest_generator.cpp +halide_define_aot_test(abstractgeneratortest) + # bit_operations_aottest.cpp # bit_operations_generator.cpp halide_define_aot_test(bit_operations) @@ -183,10 +163,6 @@ halide_define_aot_test(blur2x2) # buffer_copy_generator.cpp halide_define_aot_test(buffer_copy) -# buildmethod_aottest.cpp -# buildmethod_generator.cpp -halide_define_aot_test(buildmethod) - # can_use_target_aottest.cpp # can_use_target_generator.cpp halide_define_aot_test(can_use_target) @@ -249,15 +225,13 @@ halide_define_aot_test(error_codes) # example_aottest.cpp # example_generator.cpp -halide_define_aot_test(example) +halide_define_aot_test(example + GROUPS multithreaded) # extern_output_aottest.cpp # extern_output_generator.cpp -halide_define_aot_test(extern_output) - -# external_code_aottest.cpp -# external_code_generator.cpp -halide_define_aot_test(external_code GEN_DEPS external_code_generator_deps) +halide_define_aot_test(extern_output + GROUPS multithreaded) # float16_t_aottest.cpp # float16_t_generator.cpp @@ -267,23 +241,23 @@ halide_define_aot_test(float16_t) # gpu_multi_context_threaded_generator.cpp # (Doesn't build/link properly under wasm, and isn't useful there anyway) if (NOT Halide_TARGET MATCHES "wasm") - halide_define_aot_test(gpu_multi_context_threaded - OMIT_DEFAULT_GENERATOR - EXTRA_LIBS - gpu_multi_context_threaded_add - gpu_multi_context_threaded_mul) - - add_halide_library(gpu_multi_context_threaded_add FROM gpu_multi_context_threaded.generator - FEATURES user_context) - add_halide_library(gpu_multi_context_threaded_mul FROM gpu_multi_context_threaded.generator - FEATURES user_context) - - if (TARGET_NVPTX AND Halide_TARGET MATCHES "cuda") - add_cuda_to_target(generator_aot_gpu_multi_context_threaded PRIVATE) - endif () - if (TARGET_NVPTX AND Halide_TARGET MATCHES "opencl") - target_link_libraries(generator_aot_gpu_multi_context_threaded PRIVATE OpenCL::OpenCL) - endif () + halide_define_aot_test(gpu_multi_context_threaded + OMIT_DEFAULT_GENERATOR + EXTRA_LIBS + gpu_multi_context_threaded_add + gpu_multi_context_threaded_mul) + + add_halide_library(gpu_multi_context_threaded_add FROM gpu_multi_context_threaded.generator + FEATURES user_context) + add_halide_library(gpu_multi_context_threaded_mul FROM gpu_multi_context_threaded.generator + FEATURES user_context) + + if (TARGET_NVPTX AND Halide_TARGET MATCHES "cuda") + target_link_libraries(generator_aot_gpu_multi_context_threaded PRIVATE CUDA::cuda_driver CUDA::cudart) + endif () + if (TARGET_NVPTX AND Halide_TARGET MATCHES "opencl") + target_link_libraries(generator_aot_gpu_multi_context_threaded PRIVATE OpenCL::OpenCL) + endif () endif () # gpu_object_lifetime_aottest.cpp @@ -304,24 +278,16 @@ halide_define_aot_test(image_from_array) # mandelbrot_aottest.cpp # mandelbrot_generator.cpp -halide_define_aot_test(mandelbrot) - -# matlab_aottest.cpp -# matlab_generator.cpp -halide_define_aot_test(matlab - # Needs matlab support. See https://github.com/halide/Halide/issues/2082 - ENABLE_IF NOT ${USING_WASM} - FEATURES matlab) -if (TARGET generator_aot_matlab) - set_target_properties(generator_aot_matlab PROPERTIES ENABLE_EXPORTS TRUE) -endif () +halide_define_aot_test(mandelbrot + GROUPS multithreaded) # memory_profiler_mandelbrot_aottest.cpp # memory_profiler_mandelbrot_generator.cpp halide_define_aot_test(memory_profiler_mandelbrot # Requires profiler support (which requires threading), not yet available for wasm tests ENABLE_IF NOT ${USING_WASM} - FEATURES profile) + FEATURES profile + GROUPS multithreaded) # metadata_tester_aottest.cpp # metadata_tester_generator.cpp @@ -368,14 +334,15 @@ add_halide_library(metadata_tester_ucon # msan_aottest.cpp # msan_generator.cpp -halide_define_aot_test(msan FEATURES msan) +halide_define_aot_test(msan FEATURES msan + GROUPS multithreaded) # (Doesn't build/link properly on windows / under wasm) if (NOT Halide_TARGET MATCHES "windows" AND NOT CMAKE_SYSTEM_NAME MATCHES "Windows" AND NOT Halide_TARGET MATCHES "wasm") - # sanitizercoverage_aottest.cpp - # sanitizercoverage_generator.cpp - halide_define_aot_test(sanitizercoverage FEATURES sanitizer_coverage) -endif() + # sanitizercoverage_aottest.cpp + # sanitizercoverage_generator.cpp + halide_define_aot_test(sanitizercoverage FEATURES sanitizer_coverage) +endif () # multitarget_aottest.cpp # multitarget_generator.cpp @@ -412,7 +379,8 @@ halide_define_aot_test(output_assign) # pyramid_aottest.cpp # pyramid_generator.cpp -halide_define_aot_test(pyramid PARAMS levels=10) +halide_define_aot_test(pyramid PARAMS levels=10 + GROUPS multithreaded) # rdom_input_aottest.cpp # rdom_input_generator.cpp @@ -445,15 +413,17 @@ halide_define_aot_test(tiled_blur EXTRA_LIBS blur2x2) # user_context_aottest.cpp # user_context_generator.cpp -halide_define_aot_test(user_context FEATURES user_context) +halide_define_aot_test(user_context FEATURES user_context + GROUPS multithreaded) # user_context_insanity_aottest.cpp # user_context_insanity_generator.cpp -halide_define_aot_test(user_context_insanity FEATURES user_context) +halide_define_aot_test(user_context_insanity FEATURES user_context + GROUPS multithreaded) # variable_num_threads_aottest.cpp # variable_num_threads_generator.cpp halide_define_aot_test(variable_num_threads # Requires threading support, not yet available for wasm tests - ENABLE_IF NOT ${USING_WASM}) - + ENABLE_IF NOT ${USING_WASM} + GROUPS multithreaded) diff --git a/test/generator/abstractgeneratortest_aottest.cpp b/test/generator/abstractgeneratortest_aottest.cpp new file mode 100644 index 000000000000..51290239f862 --- /dev/null +++ b/test/generator/abstractgeneratortest_aottest.cpp @@ -0,0 +1,46 @@ +#include "HalideBuffer.h" +#include "HalideRuntime.h" + +#include +#include + +#include "abstractgeneratortest.h" + +using namespace Halide::Runtime; + +const int kSize = 4; + +void verify(const Buffer &img, float compiletime_factor, float runtime_factor, int channels) { + img.for_each_element([=](int x, int y, int c) { + int expected = (int32_t)(compiletime_factor * runtime_factor * c * (x > y ? x : y)); + int actual = img(x, y, c); + assert(expected == actual); + }); +} + +int main(int argc, char **argv) { + + const int32_t scaling = 2; // GeneratorParam + + Buffer input(kSize, kSize); + const int32_t offset = 32; + + input.for_each_element([&](int x, int y) { + input(x, y) = (x + y); + }); + + Buffer output(kSize, kSize); + abstractgeneratortest(input, offset, output); + + output.for_each_element([&](int x, int y) { + int expected = (x + y) * scaling + offset; + int actual = output(x, y); + if (expected != actual) { + fprintf(stderr, "at %d %d, expected %d, actual %d\n", x, y, expected, actual); + exit(-1); + } + }); + + printf("Success!\n"); + return 0; +} diff --git a/test/generator/abstractgeneratortest_generator.cpp b/test/generator/abstractgeneratortest_generator.cpp new file mode 100644 index 000000000000..cedf16d58dd3 --- /dev/null +++ b/test/generator/abstractgeneratortest_generator.cpp @@ -0,0 +1,149 @@ +#include "Halide.h" + +#include +#include +#include +#include + +using namespace Halide::Internal; + +namespace Halide { +namespace { + +// Note to reader: this test is meant as a simple way to verify that arbitrary +// implementations of AbstractGenerator work properly. That said, we recommend +// that you don't imitate this code; AbstractGenerator is an *internal* +// abtraction, intended for Halide to build on internally. If you use AbstractGenerator +// directly, you'll almost certainly have more work maintaining your code +// on your own. + +const char *const AbstractGeneratorTestName = "abstractgeneratortest"; + +// We could use std::stoi() here, but we explicitly want to assert-fail +// if we can't parse the string as a valid int. +int string_to_int(const std::string &s) { + std::istringstream iss(s); + int i; + iss >> i; + _halide_user_assert(!iss.fail() && iss.get() == EOF) << "Unable to parse: " << s; + return i; +} + +class AbstractGeneratorTest : public AbstractGenerator { + // Boilerplate + const GeneratorContext context_; + + // Constants (aka GeneratorParams) + GeneratorParamsMap constants_ = { + {"scaling", "2"}, + }; + + // Inputs + ImageParam input_{Int(32), 2, "input"}; + Param offset_{"offset"}; + + // Outputs + Func output_{"output"}; + + // Misc + Pipeline pipeline_; + +public: + explicit AbstractGeneratorTest(const GeneratorContext &context) + : context_(context) { + } + + std::string name() override { + return AbstractGeneratorTestName; + } + + GeneratorContext context() const override { + return context_; + } + + std::vector arginfos() override { + return { + {"input", ArgInfoDirection::Input, ArgInfoKind::Buffer, {Int(32)}, 2}, + {"offset", ArgInfoDirection::Input, ArgInfoKind::Scalar, {Int(32)}, 0}, + {"output", ArgInfoDirection::Output, ArgInfoKind::Buffer, {Int(32)}, 2}, + }; + } + + void set_generatorparam_value(const std::string &name, const std::string &value) override { + _halide_user_assert(!pipeline_.defined()); + _halide_user_assert(constants_.count(name) == 1) << "Unknown Constant: " << name; + constants_[name] = value; + } + + void set_generatorparam_value(const std::string &name, const LoopLevel &value) override { + _halide_user_assert(!pipeline_.defined()); + _halide_user_assert(constants_.count(name) == 1) << "Unknown Constant: " << name; + _halide_user_assert(false) << "This Generator has no LoopLevel constants."; + } + + Pipeline build_pipeline() override { + _halide_user_assert(!pipeline_.defined()); + + const int scaling = string_to_int(constants_.at("scaling")); + + Var x, y; + output_(x, y) = input_(x, y) * scaling + offset_; + output_.compute_root(); + + pipeline_ = output_; + return pipeline_; + } + + std::vector input_parameter(const std::string &name) override { + _halide_user_assert(pipeline_.defined()); + if (name == "input") { + return {input_.parameter()}; + } + if (name == "offset") { + return {offset_.parameter()}; + } + _halide_user_assert(false) << "Unknown input: " << name; + return {}; + } + + std::vector output_func(const std::string &name) override { + _halide_user_assert(pipeline_.defined()); + if (name == "output") { + return {output_}; + } + _halide_user_assert(false) << "Unknown output: " << name; + return {}; + } + +#ifdef HALIDE_ALLOW_GENERATOR_EXTERNAL_CODE + ExternsMap external_code_map() override { + // none + return {}; + } +#endif + + void bind_input(const std::string &name, const std::vector &v) override { + _halide_user_assert(false) << "OOPS"; + } + + void bind_input(const std::string &name, const std::vector &v) override { + _halide_user_assert(false) << "OOPS"; + } + + void bind_input(const std::string &name, const std::vector &v) override { + _halide_user_assert(false) << "OOPS"; + } + + bool emit_cpp_stub(const std::string & /*stub_file_path*/) override { + // not supported + return false; + } +}; + +RegisterGenerator register_something(AbstractGeneratorTestName, + [](const GeneratorContext &context) -> AbstractGeneratorPtr { + return std::unique_ptr(new AbstractGeneratorTest(context)); + }); + +} // namespace +} // namespace Halide diff --git a/test/generator/alias_aottest.cpp b/test/generator/alias_aottest.cpp index 80c2f61a9602..41c1a9f0ae80 100644 --- a/test/generator/alias_aottest.cpp +++ b/test/generator/alias_aottest.cpp @@ -6,6 +6,13 @@ #include "alias.h" #include "alias_with_offset_42.h" +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API +// nothing +#else +#include "alias_Adams2019.h" +#include "alias_Li2018.h" +#include "alias_Mullapudi2016.h" +#endif using namespace Halide::Runtime; @@ -18,16 +25,45 @@ int main(int argc, char **argv) { input(x) = x; }); + output.fill(0); alias(input, output); + output.copy_to_host(); input.for_each_element([=](int x) { assert(output(x) == input(x)); }); + output.fill(0); alias_with_offset_42(input, output); + output.copy_to_host(); input.for_each_element([=](int x) { assert(output(x) == input(x) + 42); }); +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API + // nothing +#else + output.fill(0); + alias_Adams2019(input, output); + output.copy_to_host(); + input.for_each_element([=](int x) { + assert(output(x) == input(x) + 2019); + }); + + output.fill(0); + alias_Li2018(input, output); + output.copy_to_host(); + input.for_each_element([=](int x) { + assert(output(x) == input(x) + 2018); + }); + + output.fill(0); + output.copy_to_host(); + alias_Mullapudi2016(input, output); + input.for_each_element([=](int x) { + assert(output(x) == input(x) + 2016); + }); +#endif + printf("Success!\n"); return 0; } diff --git a/test/generator/alias_generator.cpp b/test/generator/alias_generator.cpp index 84d3e803709f..5661588229d6 100644 --- a/test/generator/alias_generator.cpp +++ b/test/generator/alias_generator.cpp @@ -11,6 +11,15 @@ class Alias : public Halide::Generator { void generate() { Var x; output(x) = input(x) + offset; + + // set estimates for the autoschedulers + input.set_estimates({{0, 32}}); + output.set_estimates({{0, 32}}); + + if (!using_autoscheduler()) { + // Don't really need a default schedule for something this simple, but sure, why not + output.vectorize(x, natural_vector_size()).compute_root(); + } } }; @@ -18,3 +27,12 @@ class Alias : public Halide::Generator { HALIDE_REGISTER_GENERATOR(Alias, alias) HALIDE_REGISTER_GENERATOR_ALIAS(alias_with_offset_42, alias, {{"offset", "42"}}) +#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API +// nothing +#else +// Since autoscheduler-to-use is now an ordinary GeneratorParam, we can specify it in Aliases for convenience. +// (Set unique offsets just to verify these are all separate calls.) +HALIDE_REGISTER_GENERATOR_ALIAS(alias_Adams2019, alias, {{"autoscheduler", "Adams2019"}, {"offset", "2019"}}) +HALIDE_REGISTER_GENERATOR_ALIAS(alias_Li2018, alias, {{"autoscheduler", "Li2018"}, {"offset", "2018"}}) +HALIDE_REGISTER_GENERATOR_ALIAS(alias_Mullapudi2016, alias, {{"autoscheduler", "Mullapudi2016"}, {"offset", "2016"}}) +#endif diff --git a/test/generator/buildmethod_aottest.cpp b/test/generator/buildmethod_aottest.cpp deleted file mode 100644 index 3eca8be57764..000000000000 --- a/test/generator/buildmethod_aottest.cpp +++ /dev/null @@ -1,34 +0,0 @@ -#include "HalideBuffer.h" -#include "HalideRuntime.h" - -#include -#include - -#include "buildmethod.h" - -using namespace Halide::Runtime; - -const int kSize = 32; - -int main(int argc, char **argv) { - Buffer input(kSize, kSize, 3); - Buffer output(kSize, kSize, 3); - - const float compiletime_factor = 1.0f; - const float runtime_factor = 3.25f; - - input.for_each_element([&](int x, int y, int c) { - input(x, y, c) = std::max(x, y) * c; - }); - - buildmethod(input, runtime_factor, output); - - output.for_each_element([=](int x, int y, int c) { - int expected = (int32_t)(compiletime_factor * runtime_factor * c * std::max(x, y)); - int actual = output(x, y, c); - assert(expected == actual); - }); - - printf("Success!\n"); - return 0; -} diff --git a/test/generator/buildmethod_generator.cpp b/test/generator/buildmethod_generator.cpp deleted file mode 100644 index d56f7db26ecc..000000000000 --- a/test/generator/buildmethod_generator.cpp +++ /dev/null @@ -1,46 +0,0 @@ -#include "Halide.h" - -namespace { - -#ifdef HALIDE_ALLOW_GENERATOR_BUILD_METHOD -// This Generator exists solely to test old-style generators (using the -// build() method, rather than generate()/schedule()). -// Do not convert it to new-style until/unless we decide to entirely remove support -// for those Generators. -class BuildMethod : public Halide::Generator { -public: - GeneratorParam compiletime_factor{"compiletime_factor", 1, 0, 100}; - - Input> input{"input"}; - Input runtime_factor{"runtime_factor", 1.0}; - - Func build() { - Var x, y, c; - - Func g; - g(x, y, c) = cast(input(x, y, c) * compiletime_factor * runtime_factor); - return g; - } -}; -#else -// Provide a placeholder here that uses generate(), just to allow this test to -// succeed even if build() is disabled. -class BuildMethod : public Halide::Generator { -public: - GeneratorParam compiletime_factor{"compiletime_factor", 1, 0, 100}; - - Input> input{"input"}; - Input runtime_factor{"runtime_factor", 1.0}; - Output> output{"output"}; - - void generate() { - Var x, y, c; - - output(x, y, c) = cast(input(x, y, c) * compiletime_factor * runtime_factor); - } -}; -#endif - -} // namespace - -HALIDE_REGISTER_GENERATOR(BuildMethod, buildmethod) diff --git a/test/generator/configure_jittest.cpp b/test/generator/configure_jittest.cpp index 68bd5955d595..8854fc0daeda 100644 --- a/test/generator/configure_jittest.cpp +++ b/test/generator/configure_jittest.cpp @@ -31,44 +31,145 @@ int main(int argc, char **argv) { extra_value += i; } + constexpr uint16_t typed_extra_value = 4; Buffer typed_extra(kSize, kSize); - typed_extra.fill(4); - extra_value += 4; - - Var x, y, c; - Func func_extra; - func_extra(x, y, c) = cast(5); - extra_value += 5; - - const int extra_scalar = 7; - const int8_t extra_dynamic_scalar = 13; - extra_value += extra_scalar + extra_dynamic_scalar; - - const int bias = 1; - auto result = configure::generate(context, configure::Inputs{ - input, - bias, - extras[0], extras[1], extras[2], - typed_extra, - func_extra, - extra_scalar, - cast(extra_dynamic_scalar)}); - - Buffer output = result.output.realize({kSize, kSize, 3}); - Buffer extra_buffer_output = result.extra_buffer_output.realize({kSize, kSize, 3}); - Buffer extra_func_output = result.extra_func_output.realize({kSize, kSize}); - - output.for_each_element([&](int x, int y, int c) { - assert(output(x, y, c) == input(x, y, c) + bias + extra_value); - }); + typed_extra.fill(typed_extra_value); + + constexpr int extra_scalar = 7; + constexpr int8_t extra_dynamic_scalar = 13; + constexpr uint16_t extra_func_value = 5; + + constexpr int bias = 1; + + extra_value += extra_scalar + extra_dynamic_scalar + extra_func_value + typed_extra_value + bias; + + // Use a Generator Stub to create the Halide IR, + // then call realize() to JIT and execute it. + { + // When calling a Stub, Func inputs must be actual Halide::Func. + Var x, y, c; + Func func_extra; + func_extra(x, y, c) = cast(extra_func_value); + + auto result = configure::generate(context, configure::Inputs{ + input, + bias, + extras[0], extras[1], extras[2], + typed_extra, + func_extra, + extra_scalar, + cast(extra_dynamic_scalar)}); + + Buffer output = result.output.realize({kSize, kSize, 3}); + Buffer extra_buffer_output = result.extra_buffer_output.realize({kSize, kSize, 3}); + Buffer extra_func_output = result.extra_func_output.realize({kSize, kSize}); + + output.for_each_element([&](int x, int y, int c) { + assert(output(x, y, c) == input(x, y, c) + extra_value); + }); + + extra_buffer_output.for_each_element([&](int x, int y, int c) { + assert(extra_buffer_output(x, y, c) == output(x, y, c)); + }); + + extra_func_output.for_each_element([&](int x, int y) { + assert(extra_func_output(x, y) == output(x, y, 0)); + }); + } - extra_buffer_output.for_each_element([&](int x, int y, int c) { - assert(extra_buffer_output(x, y, c) == output(x, y, c)); - }); + // Alternately, instead of using Generator Stubs, we can just use the Callable interface. + // We can call this on any Generator that is registered in the current process. + { + Callable configure = create_callable_from_generator(context, "configure"); + + Buffer output(kSize, kSize, 3); + Buffer extra_buffer_output(kSize, kSize, 3); + Buffer extra_func_output(kSize, kSize); + + // All inputs to a Callable must be fully realized, so any Func inputs + // that the Generator has implicitly become Buffer inputs of the same type + // and dimensionality. + Buffer func_extra(kSize, kSize, 3); + func_extra.fill(extra_func_value); + + int r = configure(input, bias, + // extra inputs are in the order they were added, after all predeclared inputs + extras[0], extras[1], extras[2], + typed_extra, + func_extra, + extra_scalar, + extra_dynamic_scalar, + output, + // extra outputs are in the order they were added, after all predeclared outputs + extra_buffer_output, + extra_func_output); + assert(r == 0); + + output.for_each_element([&](int x, int y, int c) { + assert(output(x, y, c) == input(x, y, c) + extra_value); + }); + + extra_buffer_output.for_each_element([&](int x, int y, int c) { + assert(extra_buffer_output(x, y, c) == output(x, y, c)); + }); + + extra_func_output.for_each_element([&](int x, int y) { + assert(extra_func_output(x, y) == output(x, y, 0)); + }); + } - extra_func_output.for_each_element([&](int x, int y) { - assert(extra_func_output(x, y) == output(x, y, 0)); - }); + // We can also make an explicitly-typed std::function if we prefer. + { + auto configure = create_callable_from_generator(context, "configure") + .make_std_function< + Buffer, + int32_t, + Buffer, + Buffer, + Buffer, + Buffer, + Buffer, + int32_t, + int8_t, + Buffer, + Buffer, + Buffer>(); + + Buffer output(kSize, kSize, 3); + Buffer extra_buffer_output(kSize, kSize, 3); + Buffer extra_func_output(kSize, kSize); + + // All inputs to a Callable must be fully realized, so any Func inputs + // that the Generator has implicitly become Buffer inputs of the same type + // and dimensionality. + Buffer func_extra(kSize, kSize, 3); + func_extra.fill(extra_func_value); + + int r = configure(input, bias, + // extra inputs are in the order they were added, after all predeclared inputs + extras[0], extras[1], extras[2], + typed_extra, + func_extra, + extra_scalar, + extra_dynamic_scalar, + output, + // extra outputs are in the order they were added, after all predeclared outputs + extra_buffer_output, + extra_func_output); + assert(r == 0); + + output.for_each_element([&](int x, int y, int c) { + assert(output(x, y, c) == input(x, y, c) + extra_value); + }); + + extra_buffer_output.for_each_element([&](int x, int y, int c) { + assert(extra_buffer_output(x, y, c) == output(x, y, c)); + }); + + extra_func_output.for_each_element([&](int x, int y) { + assert(extra_func_output(x, y) == output(x, y, 0)); + }); + } printf("Success!\n"); return 0; diff --git a/test/generator/error_codes_aottest.cpp b/test/generator/error_codes_aottest.cpp index 4368ceb99423..4e62f23f87c0 100644 --- a/test/generator/error_codes_aottest.cpp +++ b/test/generator/error_codes_aottest.cpp @@ -7,7 +7,7 @@ void my_halide_error(void *user_context, const char *msg) { // Silently drop the error - //printf("%s\n", msg); + // printf("%s\n", msg); } void check(int result, int correct) { diff --git a/test/generator/example_generator.cpp b/test/generator/example_generator.cpp index 41ab28e8da2d..9997b6ccfcad 100644 --- a/test/generator/example_generator.cpp +++ b/test/generator/example_generator.cpp @@ -81,7 +81,7 @@ class Example : public Halide::Generator { runtime_factor.set_estimate(1); output.set_estimates({{0, 32}, {0, 32}, {0, 3}}); - if (!auto_schedule) { + if (!using_autoscheduler()) { output .bound(c, 0, channels) .reorder(c, x, y) diff --git a/test/generator/example_jittest.cpp b/test/generator/example_jittest.cpp index ac20c755a774..7596c4c27b68 100644 --- a/test/generator/example_jittest.cpp +++ b/test/generator/example_jittest.cpp @@ -65,6 +65,29 @@ int main(int argc, char **argv) { verify(img, 1.f, runtime_factor, 3); } + { + // Alternately, instead of using Generator Stubs, we can just use the Callable interface. + // We can call this on any Generator that is registered in the current process. + Callable example = create_callable_from_generator(context, "example"); + + Buffer img(kSize, kSize, 3); + int r = example(runtime_factor, img); + assert(r == 0); + + verify(img, 1.f, runtime_factor, 3); + } + + { + // We can also make an explicitly-typed std::function if we prefer: + auto example = create_callable_from_generator(context, "example").make_std_function>(); + + Buffer img(kSize, kSize, 3); + int r = example(runtime_factor, img); + assert(r == 0); + + verify(img, 1.f, runtime_factor, 3); + } + printf("Success!\n"); return 0; } diff --git a/test/generator/extern_output_aottest.cpp b/test/generator/extern_output_aottest.cpp index 7cfab92b2643..6156f313c0ef 100644 --- a/test/generator/extern_output_aottest.cpp +++ b/test/generator/extern_output_aottest.cpp @@ -6,13 +6,7 @@ using namespace Halide::Runtime; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - -extern "C" DLLEXPORT int extern_stage(halide_buffer_t *input, int addend, halide_buffer_t *output) { +extern "C" HALIDE_EXPORT_SYMBOL int extern_stage(halide_buffer_t *input, int addend, halide_buffer_t *output) { // Note the final output buffer argument is unused. if (input->is_bounds_query()) { for (int d = 0; d < 2; d++) { diff --git a/test/generator/external_code_aottest.cpp b/test/generator/external_code_aottest.cpp deleted file mode 100644 index cd03f04b2692..000000000000 --- a/test/generator/external_code_aottest.cpp +++ /dev/null @@ -1,35 +0,0 @@ -#include "HalideBuffer.h" -#include "HalideRuntime.h" - -#include -#include -#include -#include - -#include "external_code.h" - -using namespace std; -using namespace Halide::Runtime; - -int main() { - Buffer buf(10, 10); - - for (int i = 0; i < 10; i++) { - for (int j = 0; j < 10; j++) { - buf(i, j) = i * 65536 + j * 256; - } - } - - Buffer out(10, 10); - int ret_code = external_code(buf.raw_buffer(), out.raw_buffer()); - - assert(ret_code == 0); - - for (int i = 0; i < 10; i++) { - for (int j = 0; j < 10; j++) { - assert(out(i, j) == i * 65536 + j * 256 + 42); - } - } - printf("Success!\n"); - return 0; -} diff --git a/test/generator/external_code_extern.cpp b/test/generator/external_code_extern.cpp deleted file mode 100644 index e1915f3d3f0c..000000000000 --- a/test/generator/external_code_extern.cpp +++ /dev/null @@ -1,3 +0,0 @@ -extern "C" float gen_extern_tester(float in) { - return in + 42; -} diff --git a/test/generator/external_code_generator.cpp b/test/generator/external_code_generator.cpp deleted file mode 100644 index 49c5f4ce06f3..000000000000 --- a/test/generator/external_code_generator.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include "Halide.h" - -extern "C" unsigned char external_code_extern_bitcode_32[]; -extern "C" int external_code_extern_bitcode_32_length; -extern "C" unsigned char external_code_extern_bitcode_64[]; -extern "C" int external_code_extern_bitcode_64_length; -extern "C" unsigned char external_code_extern_cpp_source[]; -extern "C" int external_code_extern_cpp_source_length; - -namespace { - -class ExternalCode : public Halide::Generator { -public: - GeneratorParam external_code_is_bitcode{"external_code_is_bitcode", true}; - Input> input{"input"}; - Output> output{"output"}; - HalidePureExtern_1(float, gen_extern_tester, float); - - void generate() { - Var x("x"), y("y"); - Func f("f"); - - unsigned char *code; - int code_length; - const char *name = "org.halide-lang.extern_code_extern"; - if (external_code_is_bitcode) { - Target target = get_target(); - if (target.bits == 64) { - code = external_code_extern_bitcode_64; - code_length = external_code_extern_bitcode_64_length; - } else { - code = external_code_extern_bitcode_32; - code_length = external_code_extern_bitcode_32_length; - } - std::vector code_vector(code, code + code_length); - get_externs_map()->insert({name, - Halide::ExternalCode::bitcode_wrapper(target, code_vector, name)}); - } else { - code = external_code_extern_cpp_source; - code_length = external_code_extern_cpp_source_length; - std::vector code_vector(code, code + code_length); - get_externs_map()->insert({name, - Halide::ExternalCode::c_plus_plus_code_wrapper(code_vector, name)}); - } - - output(x, y) = gen_extern_tester(cast(input(x, y))); - } - - void schedule() { - } -}; - -} // namespace - -HALIDE_REGISTER_GENERATOR(ExternalCode, external_code) diff --git a/test/generator/float16_t_aottest.cpp b/test/generator/float16_t_aottest.cpp index e03ce1cdf7a2..34369bb15b0d 100644 --- a/test/generator/float16_t_aottest.cpp +++ b/test/generator/float16_t_aottest.cpp @@ -64,7 +64,7 @@ int main() { (1.0f) / (1 << 23), // 0x1.000000p-23 (-1.0f) / (1 << 23), // -0x1.000000p-23 (1.5f) / (1 << 23), // 0x1.800000p-23 - float_from_bits(0x387fc000), //0x1.ff8000p-15, + float_from_bits(0x387fc000), // 0x1.ff8000p-15, float_from_bits(0x387f8000), // 0x1.ff0000p-15, 1.0f, -1.0f}; diff --git a/test/generator/matlab_aottest.cpp b/test/generator/matlab_aottest.cpp deleted file mode 100644 index f439d6ea2670..000000000000 --- a/test/generator/matlab_aottest.cpp +++ /dev/null @@ -1,206 +0,0 @@ -#include "HalideRuntime.h" - -#include -#include -#include -#include - -// Provide a simple mock implementation of matlab's API so we can test the mexFunction. - -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - -enum mxClassID { - mxSINGLE_CLASS = 7, - mxINT32_CLASS = 12, -}; - -enum mxComplexity { - mxREAL = 0, - mxCOMPLEX, -}; - -template -mxClassID get_class_id(); -template<> -mxClassID get_class_id() { - return mxSINGLE_CLASS; -} -template<> -mxClassID get_class_id() { - return mxINT32_CLASS; -} - -class mxArray { -public: - virtual void *get_data() = 0; - virtual const void *get_data() const = 0; - virtual const size_t *get_dimensions() const = 0; - virtual size_t get_number_of_dimensions() const = 0; - virtual mxClassID get_class_id() const = 0; - virtual double get_scalar() const = 0; - virtual size_t get_element_size() const = 0; - - virtual ~mxArray() { - } -}; - -template -class mxArrayImpl : public mxArray { - std::vector data; - std::vector dims; - -public: - mxArrayImpl(size_t M, size_t N) - : data(M * N), dims({M, N}) { - } - - void *get_data() override { - return &data[0]; - } - const void *get_data() const override { - return &data[0]; - } - const size_t *get_dimensions() const override { - return &dims[0]; - } - size_t get_number_of_dimensions() const override { - return dims.size(); - } - mxClassID get_class_id() const override { - return ::get_class_id(); - } - double get_scalar() const override { - return data[0]; - } - size_t get_element_size() const override { - return sizeof(T); - } - - T &operator()(int i, int j) { - return data[i * dims[0] + j]; - } - T operator()(int i, int j) const { - return data[i * dims[0] + j]; - } -}; - -extern "C" { - -DLLEXPORT int mexWarnMsgTxt(const char *msg) { - // Don't bother with the varargs. - printf("%s\n", msg); - return 0; -} - -DLLEXPORT size_t mxGetNumberOfDimensions_730(const mxArray *a) { - return a->get_number_of_dimensions(); -} - -DLLEXPORT int mxGetNumberOfDimensions_700(const mxArray *a) { - return (int)a->get_number_of_dimensions(); -} - -DLLEXPORT const size_t *mxGetDimensions_730(const mxArray *a) { - return a->get_dimensions(); -} - -DLLEXPORT const int *mxGetDimensions_700(const mxArray *a) { - assert(sizeof(size_t) == sizeof(int)); - return reinterpret_cast(a->get_dimensions()); -} - -DLLEXPORT mxClassID mxGetClassID(const mxArray *a) { - return a->get_class_id(); -} - -DLLEXPORT void *mxGetData(const mxArray *a) { - return const_cast(a)->get_data(); -} - -DLLEXPORT size_t mxGetElementSize(const mxArray *a) { - return a->get_element_size(); -} - -// We only support real, numeric classes in this mock implementation. -DLLEXPORT bool mxIsNumeric(const mxArray *a) { - return true; -} -DLLEXPORT bool mxIsLogical(const mxArray *a) { - return false; -} -DLLEXPORT bool mxIsComplex(const mxArray *a) { - return false; -} - -DLLEXPORT double mxGetScalar(const mxArray *a) { - return a->get_scalar(); -} - -DLLEXPORT mxArray *mxCreateNumericMatrix_730(size_t M, size_t N, mxClassID type, mxComplexity complexity) { - assert(complexity == mxREAL); - switch (type) { - case mxSINGLE_CLASS: - return new mxArrayImpl(M, N); - case mxINT32_CLASS: - return new mxArrayImpl(M, N); - default: - return nullptr; - } -} - -DLLEXPORT mxArray *mxCreateNumericMatrix_700(int M, int N, mxClassID type, mxComplexity complexity) { - return mxCreateNumericMatrix_730(M, N, type, complexity); -} - -void mexFunction(int, mxArray **, int, mxArray **); -} - -int main(int argc, char **argv) { - mxArray *lhs[1] = {nullptr}; - mxArray *rhs[4] = { - nullptr, - }; - - mxArrayImpl input(3, 5); - mxArrayImpl scale(1, 1); - mxArrayImpl negate(1, 1); - mxArrayImpl output(3, 5); - - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 5; j++) { - input(i, j) = (float)(i * 5 + j); - } - } - - scale(0, 0) = 3.0f; - negate(0, 0) = 1; - - rhs[0] = &input; - rhs[1] = &scale; - rhs[2] = &negate; - rhs[3] = &output; - - mexFunction(1, lhs, 4, rhs); - - assert(lhs[0]->get_scalar() == 0); - delete lhs[0]; - lhs[0] = nullptr; - - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 5; j++) { - float in = input(i, j); - float expected = in * scale(0, 0) * (negate(0, 0) ? -1.0f : 1.0f); - if (output(i, j) == expected) { - printf("output(%d, %d) = %f instead of %f\n", - i, j, output(i, j), expected); - } - } - } - - printf("Success!\n"); - return 0; -} diff --git a/test/generator/matlab_generator.cpp b/test/generator/matlab_generator.cpp deleted file mode 100644 index 1dbcc80e3cfb..000000000000 --- a/test/generator/matlab_generator.cpp +++ /dev/null @@ -1,22 +0,0 @@ -#include "Halide.h" - -using namespace Halide; - -namespace { - -class Matlab : public Halide::Generator { -public: - Input> input{"input"}; - Input scale{"scale"}; - Input negate{"negate"}; - Output> output{"output"}; - - void generate() { - Var x, y; - output(x, y) = input(x, y) * scale * select(negate, -1.0f, 1.0f); - } -}; - -} // namespace - -HALIDE_REGISTER_GENERATOR(Matlab, matlab) diff --git a/test/generator/msan_generator.cpp b/test/generator/msan_generator.cpp index 5945054cf3e7..ca12ef30fc74 100644 --- a/test/generator/msan_generator.cpp +++ b/test/generator/msan_generator.cpp @@ -28,6 +28,9 @@ class MSAN : public Halide::Generator { input.dim(0).set_stride(Expr()).set_extent(4).dim(1).set_extent(4).dim(2).set_extent(3); output.parallel(y).vectorize(x, 4); output.dim(0).set_stride(Expr()).set_extent(4).dim(1).set_extent(4).dim(2).set_extent(3); + // Silence warnings. + output.update(0).unscheduled(); + output.update(1).unscheduled(); } private: diff --git a/test/generator/stubtest_generator.cpp b/test/generator/stubtest_generator.cpp index 08d7a6e6751e..8f5b41640e41 100644 --- a/test/generator/stubtest_generator.cpp +++ b/test/generator/stubtest_generator.cpp @@ -78,7 +78,7 @@ class StubTest : public Halide::Generator { // Verify that Output::type() and ::dims() are well-defined after we define the Func assert(tuple_output.types()[0] == Float(32)); assert(tuple_output.types()[1] == Float(32)); - assert(tuple_output.dims() == 3); + assert(tuple_output.dimensions() == 3); array_output.resize(array_input.size()); for (size_t i = 0; i < array_input.size(); ++i) { @@ -92,8 +92,10 @@ class StubTest : public Halide::Generator { } void schedule() { - intermediate.compute_at(intermediate_level); - intermediate.specialize(vectorize).vectorize(x, natural_vector_size()); + if (!using_autoscheduler()) { + intermediate.compute_at(intermediate_level); + intermediate.specialize(vectorize).vectorize(x, natural_vector_size()); + } } private: diff --git a/test/generator/stubtest_jittest.cpp b/test/generator/stubtest_jittest.cpp index 9fc4b43b9f20..1c0aa3f8fc14 100644 --- a/test/generator/stubtest_jittest.cpp +++ b/test/generator/stubtest_jittest.cpp @@ -23,11 +23,11 @@ Buffer make_image(int extra) { return im; } -template -void verify(const Buffer &input, float float_arg, int int_arg, const Buffer &output) { +template +void verify(const Buffer &input, float float_arg, int int_arg, const Buffer &output) { if (input.width() != output.width() || input.height() != output.height()) { - fprintf(stderr, "size mismatch\n"); + fprintf(stderr, "size mismatch: %dx%d vs %dx%d\n", input.width(), input.height(), output.width(), output.height()); exit(-1); } int channels = std::max(1, std::min(input.channels(), output.channels())); @@ -37,8 +37,7 @@ void verify(const Buffer &input, float float_arg, int int_arg, con const OutputType expected = static_cast(input(x, y, c) * float_arg + int_arg); const OutputType actual = output(x, y, c); if (expected != actual) { - fprintf(stderr, "img[%d, %d, %d] = %f, expected %f\n", x, y, c, (double)actual, (double)expected); - abort(); + fprintf(stderr, "img[%d, %d, %d] = %f, expected %f (input = %f)\n", x, y, c, (double)actual, (double)expected, (double)input(x, y, c)); exit(-1); } } @@ -60,56 +59,230 @@ int main(int argc, char **argv) { // the Stub wants Expr, so make a conversion in place std::vector int_args_expr(int_args.begin(), int_args.end()); - // Pass in a set of GeneratorParams: even though we aren't customizing - // the values, we can set the LoopLevel values after-the-fact. - StubTest::GeneratorParams gp; - auto gen = StubTest::generate( - GeneratorContext(get_jit_target_from_environment()), - // Use aggregate-initialization syntax to fill in an Inputs struct. - { - buffer_input, // typed_buffer_input - buffer_input, // untyped_buffer_input - {buffer_input, buffer_input}, - Func(simple_input), - {Func(array_input[0]), Func(array_input[1])}, - 1.25f, - int_args_expr}, - gp); + GeneratorContext context(get_jit_target_from_environment()); + + { + // Pass in a set of GeneratorParams: even though we aren't customizing + // the values, we can set the LoopLevel values after-the-fact. + StubTest::GeneratorParams gp; + auto gen = StubTest::generate( + context, + // Use aggregate-initialization syntax to fill in an Inputs struct. + { + buffer_input, // typed_buffer_input + buffer_input, // untyped_buffer_input + {buffer_input, buffer_input}, + Func(simple_input), + {Func(array_input[0]), Func(array_input[1])}, + 1.25f, + int_args_expr}, + gp); + + gp.intermediate_level.set(LoopLevel(gen.tuple_output, gen.tuple_output.args().at(1))); - gp.intermediate_level.set(LoopLevel(gen.tuple_output, gen.tuple_output.args().at(1))); + Realization simple_output_realized = gen.simple_output.realize({kSize, kSize, 3}); + Buffer s0 = simple_output_realized; + verify(array_input[0], 1.f, 0, s0); - Realization simple_output_realized = gen.simple_output.realize({kSize, kSize, 3}); - Buffer s0 = simple_output_realized; - verify(array_input[0], 1.f, 0, s0); + Realization tuple_output_realized = gen.tuple_output.realize({kSize, kSize, 3}); + Buffer f0 = tuple_output_realized[0]; + Buffer f1 = tuple_output_realized[1]; + verify(array_input[0], 1.25f, 0, f0); + verify(array_input[0], 1.25f, 33, f1); + + for (int i = 0; i < kArrayCount; ++i) { + Realization array_output_realized = gen.array_output[i].realize({kSize, kSize, 3}, gen.target); + Buffer g0 = array_output_realized; + verify(array_input[i], 1.0f, int_args[i], g0); + } - Realization tuple_output_realized = gen.tuple_output.realize({kSize, kSize, 3}); - Buffer f0 = tuple_output_realized[0]; - Buffer f1 = tuple_output_realized[1]; - verify(array_input[0], 1.25f, 0, f0); - verify(array_input[0], 1.25f, 33, f1); + Realization typed_buffer_output_realized = gen.typed_buffer_output.realize({kSize, kSize, 3}); + Buffer b0 = typed_buffer_output_realized; + verify(buffer_input, 1.f, 0, b0); - for (int i = 0; i < kArrayCount; ++i) { - Realization array_output_realized = gen.array_output[i].realize({kSize, kSize, 3}, gen.target); - Buffer g0 = array_output_realized; - verify(array_input[i], 1.0f, int_args[i], g0); + Realization untyped_buffer_output_realized = gen.untyped_buffer_output.realize({kSize, kSize, 3}); + Buffer b1 = untyped_buffer_output_realized; + verify(buffer_input, 1.f, 0, b1); + + Realization static_compiled_buffer_output_realized = gen.static_compiled_buffer_output.realize({kSize, kSize, 3}); + Buffer b2 = static_compiled_buffer_output_realized; + verify(buffer_input, 1.f, 42, b2); + + for (int i = 0; i < 2; ++i) { + Realization array_buffer_output_realized = gen.array_buffer_output[i].realize({kSize, kSize, 3}); + Buffer b2 = array_buffer_output_realized; + verify(buffer_input, 1.f, 1 + i, b2); + } } - Realization typed_buffer_output_realized = gen.typed_buffer_output.realize({kSize, kSize, 3}); - Buffer b0 = typed_buffer_output_realized; - verify(buffer_input, 1.f, 0, b0); + // Alternately, instead of using Generator Stubs, we can just use the Callable interface. + // We can call this on any Generator that is registered in the current process. + { + Buffer buffer_input = make_image(0); + Buffer simple_input = make_image(0); + Buffer array_input0 = make_image(0); + Buffer array_input1 = make_image(1); + Buffer typed_buffer_output(kSize, kSize, 3); + Buffer untyped_buffer_output(kSize, kSize, 3); + Buffer tupled_output0(kSize, kSize, 3); + Buffer tupled_output1(kSize, kSize, 3); + Buffer array_buffer_input0 = make_image(0); + Buffer array_buffer_input1 = make_image(1); + Buffer simple_output(kSize, kSize, 3); + // TODO: see Issues #3709, #3967 + Buffer float16_output(halide_type_t(halide_type_float, 16), kSize, kSize, 3); + Buffer bfloat16_output(halide_type_t(halide_type_bfloat, 16), kSize, kSize, 3); + Buffer tuple_output0(kSize, kSize, 3), tuple_output1(kSize, kSize, 3); + Buffer array_output0(kSize, kSize, 3), array_output1(kSize, kSize, 3); + Buffer static_compiled_buffer_output(kSize, kSize, 3); + Buffer array_buffer_output0(kSize, kSize, 3), array_buffer_output1(kSize, kSize, 3); + + // Note that this Generator has several GeneratorParams that need to be set correctly + // before compilation -- in the Stub case above, the values end up being inferred + // from the specific inputs we provide, but for the JIT (and AOT) cases, there are + // no such inputs available, so we must be explicit. (Note that these are the same + // values specified in our Make/CMake files.) + const GeneratorParamsMap gp = { + {"untyped_buffer_input.type", "uint8"}, + {"untyped_buffer_input.dim", "3"}, + {"simple_input.type", "float32"}, + {"array_input.type", "float32"}, + {"array_input.size", "2"}, + {"int_arg.size", "2"}, + {"tuple_output.type", "float32,float32"}, + {"vectorize", "true"}, + }; - Realization untyped_buffer_output_realized = gen.untyped_buffer_output.realize({kSize, kSize, 3}); - Buffer b1 = untyped_buffer_output_realized; - verify(buffer_input, 1.f, 0, b1); + Callable stubtest = create_callable_from_generator(context, "stubtest", gp); - Realization static_compiled_buffer_output_realized = gen.static_compiled_buffer_output.realize({kSize, kSize, 3}); - Buffer b2 = static_compiled_buffer_output_realized; - verify(buffer_input, 1.f, 42, b2); + int r = stubtest( + buffer_input, + buffer_input, + array_buffer_input0, array_buffer_input1, + simple_input, + array_input0, array_input1, + 1.25f, + 33, + 66, + simple_output, + tuple_output0, tuple_output1, + array_output0, array_output1, + typed_buffer_output, + untyped_buffer_output, + tupled_output0, tupled_output1, + static_compiled_buffer_output, + array_buffer_output0, array_buffer_output1, + float16_output, + bfloat16_output); + assert(r == 0); + + verify(buffer_input, 1.f, 0, typed_buffer_output); + verify(buffer_input, 1.f, 0, untyped_buffer_output); + verify(simple_input, 1.f, 0, simple_output); + verify(simple_input, 1.f, 0, tupled_output0); + verify(simple_input, 1.f, 1, tupled_output1); + verify(array_input0, 1.f, 0, simple_output); + verify(array_input0, 1.25f, 0, tuple_output0); + verify(array_input0, 1.25f, 33, tuple_output1); + verify(array_input0, 1.0f, 33, array_output0); + verify(array_input1, 1.0f, 66, array_output1); + verify(buffer_input, 1.0f, 42, static_compiled_buffer_output); + verify(array_buffer_input0, 1.f, 1, array_buffer_output0); + verify(array_buffer_input1, 1.f, 2, array_buffer_output1); + } + + // We can also make an explicitly-typed std::function if we prefer. + { + Buffer buffer_input = make_image(0); + Buffer simple_input = make_image(0); + Buffer array_input0 = make_image(0); + Buffer array_input1 = make_image(1); + Buffer typed_buffer_output(kSize, kSize, 3); + Buffer untyped_buffer_output(kSize, kSize, 3); + Buffer tupled_output0(kSize, kSize, 3); + Buffer tupled_output1(kSize, kSize, 3); + Buffer array_buffer_input0 = make_image(0); + Buffer array_buffer_input1 = make_image(1); + Buffer simple_output(kSize, kSize, 3); + // TODO: see Issues #3709, #3967 + Buffer float16_output(halide_type_t(halide_type_float, 16), kSize, kSize, 3); + Buffer bfloat16_output(halide_type_t(halide_type_bfloat, 16), kSize, kSize, 3); + Buffer tuple_output0(kSize, kSize, 3), tuple_output1(kSize, kSize, 3); + Buffer array_output0(kSize, kSize, 3), array_output1(kSize, kSize, 3); + Buffer static_compiled_buffer_output(kSize, kSize, 3); + Buffer array_buffer_output0(kSize, kSize, 3), array_buffer_output1(kSize, kSize, 3); + + // Note that this Generator has several GeneratorParams that need to be set correctly + // before compilation -- in the Stub case above, the values end up being inferred + // from the specific inputs we provide, but for the JIT (and AOT) cases, there are + // no such inputs available, so we must be explicit. (Note that these are the same + // values specified in our Make/CMake files.) + const GeneratorParamsMap gp = { + {"untyped_buffer_input.type", "uint8"}, + {"untyped_buffer_input.dim", "3"}, + {"simple_input.type", "float32"}, + {"array_input.type", "float32"}, + {"array_input.size", "2"}, + {"int_arg.size", "2"}, + {"tuple_output.type", "float32,float32"}, + {"vectorize", "true"}, + }; + + auto stubtest = create_callable_from_generator(context, "stubtest", gp) + .make_std_function< + Buffer, + Buffer, + Buffer, Buffer, + Buffer, + Buffer, Buffer, + float, + int32_t, + int32_t, + Buffer, + Buffer, Buffer, + Buffer, Buffer, + Buffer, + Buffer, + Buffer, Buffer, + Buffer, + Buffer, Buffer, + Buffer, + Buffer>(); + + int r = stubtest( + buffer_input, + buffer_input, + array_buffer_input0, array_buffer_input1, + simple_input, + array_input0, array_input1, + 1.25f, + 33, + 66, + simple_output, + tuple_output0, tuple_output1, + array_output0, array_output1, + typed_buffer_output, + untyped_buffer_output, + tupled_output0, tupled_output1, + static_compiled_buffer_output, + array_buffer_output0, array_buffer_output1, + float16_output, + bfloat16_output); + assert(r == 0); - for (int i = 0; i < 2; ++i) { - Realization array_buffer_output_realized = gen.array_buffer_output[i].realize({kSize, kSize, 3}); - Buffer b2 = array_buffer_output_realized; - verify(buffer_input, 1.f, 1 + i, b2); + verify(buffer_input, 1.f, 0, typed_buffer_output); + verify(buffer_input, 1.f, 0, untyped_buffer_output); + verify(simple_input, 1.f, 0, simple_output); + verify(simple_input, 1.f, 0, tupled_output0); + verify(simple_input, 1.f, 1, tupled_output1); + verify(array_input0, 1.f, 0, simple_output); + verify(array_input0, 1.25f, 0, tuple_output0); + verify(array_input0, 1.25f, 33, tuple_output1); + verify(array_input0, 1.0f, 33, array_output0); + verify(array_input1, 1.0f, 66, array_output1); + verify(buffer_input, 1.0f, 42, static_compiled_buffer_output); + verify(array_buffer_input0, 1.f, 1, array_buffer_output0); + verify(array_buffer_input1, 1.f, 2, array_buffer_output1); } printf("Success!\n"); diff --git a/test/generator/stubuser_aottest.cpp b/test/generator/stubuser_aottest.cpp index 11646b643091..9a03308934e0 100644 --- a/test/generator/stubuser_aottest.cpp +++ b/test/generator/stubuser_aottest.cpp @@ -2,8 +2,10 @@ #include "HalideBuffer.h" #include "HalideRuntime.h" +#include "halide_benchmark.h" #include "stubuser.h" +#include "stubuser_auto.h" using namespace Halide::Runtime; @@ -62,15 +64,23 @@ int main(int argc, char **argv) { Buffer float16_output(halide_type_t(halide_type_float, 16), kSize, kSize, 3); Buffer bfloat16_output(halide_type_t(halide_type_bfloat, 16), kSize, kSize, 3); - stubuser(input, calculated_output, float32_buffer_output, int32_buffer_output, - array_test_output, tupled_output0, tupled_output1, int_output, - float16_output, bfloat16_output); - verify(input, kFloatArg, kIntArg, kOffset, calculated_output); - verify(input, 1.f, 0, 0.f, float32_buffer_output); - verify(input, 1.f, 0, 0.f, int32_buffer_output); - verify(input, 1.f, 0, 2, array_test_output); - verify(input, 1.f, 0, 0, tupled_output0); - verify(input, 1.f, 1, 3, int_output); + struct FnInfo { + decltype(&stubuser) f; + const char *const name; + }; + FnInfo fns[2] = {{stubuser, "stubuser"}, {stubuser_auto, "stubuser_auto"}}; + for (auto f : fns) { + printf("Testing %s...\n", f.name); + f.f(input, calculated_output, float32_buffer_output, int32_buffer_output, + array_test_output, tupled_output0, tupled_output1, int_output, + float16_output, bfloat16_output); + verify(input, kFloatArg, kIntArg, kOffset, calculated_output); + verify(input, 1.f, 0, 0.f, float32_buffer_output); + verify(input, 1.f, 0, 0.f, int32_buffer_output); + verify(input, 1.f, 0, 2, array_test_output); + verify(input, 1.f, 0, 0, tupled_output0); + verify(input, 1.f, 1, 3, int_output); + } printf("Success!\n"); return 0; diff --git a/test/generator/stubuser_generator.cpp b/test/generator/stubuser_generator.cpp index 7fbe0c2137ac..f5400f95912a 100644 --- a/test/generator/stubuser_generator.cpp +++ b/test/generator/stubuser_generator.cpp @@ -94,6 +94,18 @@ class StubUser : public Halide::Generator { extra_scalar, cast(extra_dynamic_scalar)}) .output; + + // Estimates (for autoscheduler): + constexpr int kSize = 32; + input.set_estimates({{0, kSize}, {0, kSize}, {0, 3}}); + calculated_output.set_estimates({{0, kSize}, {0, kSize}, {0, 3}}); + float32_buffer_output.set_estimates({{0, kSize}, {0, kSize}, {0, 3}}); + int32_buffer_output.set_estimates({{0, kSize}, {0, kSize}, {0, 3}}); + array_test_output.set_estimates({{0, kSize}, {0, kSize}, {0, 3}}); + tupled_output.set_estimates({{0, kSize}, {0, kSize}, {0, 3}}); + int_output.set_estimates({{0, kSize}, {0, kSize}, {0, 3}}); + float16_output.set_estimates({{0, kSize}, {0, kSize}, {0, 3}}); + bfloat16_output.set_estimates({{0, kSize}, {0, kSize}, {0, 3}}); } }; diff --git a/test/integration/CMakeLists.txt b/test/integration/CMakeLists.txt index 9b1abbb6a60c..600b1d765a36 100644 --- a/test/integration/CMakeLists.txt +++ b/test/integration/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(integration_tests NONE) enable_testing() diff --git a/test/integration/aot/CMakeLists.txt b/test/integration/aot/CMakeLists.txt index 4e6ae6d2b3d0..b370a642339b 100644 --- a/test/integration/aot/CMakeLists.txt +++ b/test/integration/aot/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(aot) enable_testing() diff --git a/test/integration/jit/CMakeLists.txt b/test/integration/jit/CMakeLists.txt index c8d5a6546a96..a6f24342184a 100644 --- a/test/integration/jit/CMakeLists.txt +++ b/test/integration/jit/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(jit) enable_testing() diff --git a/test/integration/xc/CMakeLists.txt b/test/integration/xc/CMakeLists.txt index 0019a702bdac..8552e9cefc62 100644 --- a/test/integration/xc/CMakeLists.txt +++ b/test/integration/xc/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.22) project(xc) enable_testing() diff --git a/test/internal.cpp b/test/internal.cpp index 5baf0dab47a2..08283fa9cf54 100644 --- a/test/internal.cpp +++ b/test/internal.cpp @@ -16,6 +16,7 @@ #include "Monotonic.h" #include "Reduction.h" #include "Solve.h" +#include "SpirvIR.h" #include "UniquifyVariableNames.h" using namespace Halide; @@ -39,6 +40,7 @@ int main(int argc, const char **argv) { generator_test(); propagate_estimate_test(); uniquify_variable_names_test(); + spirv_ir_test(); printf("Success!\n"); return 0; diff --git a/test/performance/CMakeLists.txt b/test/performance/CMakeLists.txt index c317326fbe40..d1e869d97f07 100644 --- a/test/performance/CMakeLists.txt +++ b/test/performance/CMakeLists.txt @@ -1,35 +1,45 @@ +if (Halide_ANY_SANITIZERS_ENABLED) + # All sanitizers impact performance, so don't even bother with this test suite + message(STATUS "Skipping all performance testing because at least one Sanitizer is enabled.") + return() +endif() + tests(GROUPS performance SOURCES - tiled_matmul.cpp async_gpu.cpp block_transpose.cpp boundary_conditions.cpp clamped_vector_load.cpp const_division.cpp - fan_in.cpp fast_inverse.cpp fast_pow.cpp fast_sine_cosine.cpp gpu_half_throughput.cpp - inner_loop_parallel.cpp jit_stress.cpp lots_of_inputs.cpp - lots_of_small_allocations.cpp - matrix_multiplication.cpp memcpy.cpp - memory_profiler.cpp nested_vectorization_gemm.cpp packed_planar_fusion.cpp + realize_overhead.cpp + rgb_interleaved.cpp + tiled_matmul.cpp + vectorize.cpp + wrap.cpp + ) + +tests(GROUPS performance multithreaded + SOURCES + fan_in.cpp + inner_loop_parallel.cpp + lots_of_small_allocations.cpp + matrix_multiplication.cpp + memory_profiler.cpp parallel_performance.cpp profiler.cpp - realize_overhead.cpp rfactor.cpp - rgb_interleaved.cpp - stack_vs_heap.cpp sort.cpp + stack_vs_heap.cpp thread_safe_jit.cpp - vectorize.cpp - wrap.cpp ) # Make sure that performance tests do not run in parallel with other tests, diff --git a/test/performance/const_division.cpp b/test/performance/const_division.cpp index 2ddad95cf56a..f5f245dd0376 100644 --- a/test/performance/const_division.cpp +++ b/test/performance/const_division.cpp @@ -86,7 +86,6 @@ bool test(int w, bool div, bool round_to_zero) { h.vectorize(x); } Target t = get_jit_target_from_environment(); - t.set_feature(Target::DisableLLVMLoopOpt); f.compile_jit(t); g.compile_jit(t); h.compile_jit(t); diff --git a/test/performance/fast_pow.cpp b/test/performance/fast_pow.cpp index d2c1de9fec57..706b435ddcf5 100644 --- a/test/performance/fast_pow.cpp +++ b/test/performance/fast_pow.cpp @@ -6,14 +6,8 @@ using namespace Halide; using namespace Halide::Tools; -#ifdef _WIN32 -#define DLLEXPORT __declspec(dllexport) -#else -#define DLLEXPORT -#endif - // powf() is a macro in some environments, so always wrap it -extern "C" DLLEXPORT float pow_ref(float x, float y) { +extern "C" HALIDE_EXPORT_SYMBOL float pow_ref(float x, float y) { return powf(x, y); } HalideExtern_2(float, pow_ref, float, float); diff --git a/test/performance/nested_vectorization_gemm.cpp b/test/performance/nested_vectorization_gemm.cpp index 25a0bc746fb1..1c17e965d3c2 100644 --- a/test/performance/nested_vectorization_gemm.cpp +++ b/test/performance/nested_vectorization_gemm.cpp @@ -10,9 +10,6 @@ int main(int argc, char **argv) { printf("[SKIP] Performance tests are meaningless and/or misleading under WebAssembly interpreter.\n"); return 0; } - // We don't want to be sensitive to LLVM pulling the same tricks - // or not. - target.set_feature(Target::DisableLLVMLoopOpt); // 8-bit mat-mul into 32-bit accumulator { @@ -286,7 +283,7 @@ int main(int argc, char **argv) { Buffer out(f_buf.width() - g_buf.width() - 128); // Uncomment to check the asm - //result.compile_to_assembly("/dev/stdout", {f, g}, target); + // result.compile_to_assembly("/dev/stdout", {f, g}, target); times[use_nested_vectorization] = Tools::benchmark(10, 10, [&]() { diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index 2fd90683bd38..03bd243ef554 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -140,8 +140,8 @@ bool matmul(Halide::Target target) { Func result = mm.in(); // Uncomment to check the asm - //result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.ll", {A, B}, target); - //result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); + // result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.ll", {A, B}, target); + // result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); auto time = Tools::benchmark(20, 20, [&]() { result.realize(out); @@ -222,8 +222,8 @@ bool matmul_bf16(Halide::Target target) { Buffer out(col, row); // Uncomment to check the asm - //result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul_bf16.ll", {A, B}, target); - //result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); + // result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul_bf16.ll", {A, B}, target); + // result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); auto time = Tools::benchmark(20, 20, [&]() { result.realize(out); diff --git a/test/runtime/CMakeLists.txt b/test/runtime/CMakeLists.txt new file mode 100644 index 000000000000..54c219ffa392 --- /dev/null +++ b/test/runtime/CMakeLists.txt @@ -0,0 +1,32 @@ +function(halide_define_runtime_internal_test NAME) + add_executable(runtime_internal_${NAME} ${NAME}.cpp) + target_link_libraries(runtime_internal_${NAME} PRIVATE Halide::Test) + target_include_directories(runtime_internal_${NAME} PRIVATE "${Halide_SOURCE_DIR}/src") + target_include_directories(runtime_internal_${NAME} PRIVATE "${Halide_SOURCE_DIR}/src/runtime") + target_link_libraries(runtime_internal_${NAME} PRIVATE Halide::Runtime) + if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + # Halide runtime lib has declarations for memcmp etc that conflict with GNU stdlib + target_compile_options(runtime_internal_${NAME} PRIVATE -Wno-builtin-declaration-mismatch ) + endif() + target_compile_definitions( + runtime_internal_${NAME} + PRIVATE + HALIDE_VERSION=${Halide_VERSION} + HALIDE_VERSION_MAJOR=${Halide_VERSION_MAJOR} + HALIDE_VERSION_MINOR=${Halide_VERSION_MINOR} + HALIDE_VERSION_PATCH=${Halide_VERSION_PATCH} + COMPILING_HALIDE_RUNTIME + COMPILING_HALIDE_RUNTIME_TESTS + ) + add_halide_test(runtime_internal_${NAME} GROUPS runtime_internal) +endfunction() + +# NOTE: These tests directly include runtime_internal.h which isn't compatible with MSVC +if(NOT MSVC) + halide_define_runtime_internal_test(block_allocator) + halide_define_runtime_internal_test(block_storage) + halide_define_runtime_internal_test(linked_list) + halide_define_runtime_internal_test(memory_arena) + halide_define_runtime_internal_test(string_storage) + halide_define_runtime_internal_test(string_table) +endif() \ No newline at end of file diff --git a/test/runtime/block_allocator.cpp b/test/runtime/block_allocator.cpp new file mode 100644 index 000000000000..d147f652e80d --- /dev/null +++ b/test/runtime/block_allocator.cpp @@ -0,0 +1,153 @@ +#include "common.h" + +#include "internal/block_allocator.h" +#include "internal/pointer_table.h" + +using namespace Halide::Runtime::Internal; + +namespace { + +size_t allocated_region_memory = 0; +size_t allocated_block_memory = 0; + +void allocate_block(void *user_context, MemoryBlock *block) { + block->handle = allocate_system(user_context, block->size); + allocated_block_memory += block->size; + + debug(user_context) << "Test : allocate_block (" + << "block=" << (void *)(block) << " " + << "block_size=" << int32_t(block->size) << " " + << "allocated_block_memory=" << int32_t(allocated_block_memory) << " " + << ") !\n"; +} + +void deallocate_block(void *user_context, MemoryBlock *block) { + deallocate_system(user_context, block->handle); + allocated_block_memory -= block->size; + + debug(user_context) << "Test : deallocate_block (" + << "block=" << (void *)(block) << " " + << "block_size=" << int32_t(block->size) << " " + << "allocated_block_memory=" << int32_t(allocated_block_memory) << " " + << ") !\n"; +} + +void allocate_region(void *user_context, MemoryRegion *region) { + region->handle = (void *)1; + allocated_region_memory += region->size; + + debug(user_context) << "Test : allocate_region (" + << "region=" << (void *)(region) << " " + << "region_size=" << int32_t(region->size) << " " + << "allocated_region_memory=" << int32_t(allocated_region_memory) << " " + << ") !\n"; +} + +void deallocate_region(void *user_context, MemoryRegion *region) { + region->handle = (void *)0; + allocated_region_memory -= region->size; + + debug(user_context) << "Test : deallocate_region (" + << "region=" << (void *)(region) << " " + << "region_size=" << int32_t(region->size) << " " + << "allocated_region_memory=" << int32_t(allocated_region_memory) << " " + << ") !\n"; +} + +} // end namespace + +int main(int argc, char **argv) { + void *user_context = (void *)1; + + SystemMemoryAllocatorFns system_allocator = {allocate_system, deallocate_system}; + MemoryBlockAllocatorFns block_allocator = {allocate_block, deallocate_block}; + MemoryRegionAllocatorFns region_allocator = {allocate_region, deallocate_region}; + + // test class interface + { + BlockAllocator::Config config = {0}; + config.minimum_block_size = 1024; + + BlockAllocator::MemoryAllocators allocators = {system_allocator, block_allocator, region_allocator}; + BlockAllocator *instance = BlockAllocator::create(user_context, config, allocators); + + MemoryRequest request = {0}; + request.size = sizeof(int); + request.alignment = sizeof(int); + request.properties.visibility = MemoryVisibility::DefaultVisibility; + request.properties.caching = MemoryCaching::DefaultCaching; + request.properties.usage = MemoryUsage::DefaultUsage; + + MemoryRegion *r1 = instance->reserve(user_context, request); + halide_abort_if_false(user_context, r1 != nullptr); + halide_abort_if_false(user_context, allocated_block_memory == config.minimum_block_size); + halide_abort_if_false(user_context, allocated_region_memory == request.size); + + MemoryRegion *r2 = instance->reserve(user_context, request); + halide_abort_if_false(user_context, r2 != nullptr); + halide_abort_if_false(user_context, allocated_block_memory == config.minimum_block_size); + halide_abort_if_false(user_context, allocated_region_memory == (2 * request.size)); + + instance->reclaim(user_context, r1); + halide_abort_if_false(user_context, allocated_region_memory == (1 * request.size)); + + instance->destroy(user_context); + debug(user_context) << "Test : block_allocator::destroy (" + << "allocated_block_memory=" << int32_t(allocated_block_memory) << " " + << "allocated_region_memory=" << int32_t(allocated_region_memory) << " " + << ") !\n"; + + halide_abort_if_false(user_context, allocated_block_memory == 0); + halide_abort_if_false(user_context, allocated_region_memory == 0); + + BlockAllocator::destroy(user_context, instance); + + debug(user_context) << "Test : block_allocator::destroy (" + << "allocated_system_memory=" << int32_t(allocated_system_memory) << " " + << ") !\n"; + + halide_abort_if_false(user_context, allocated_system_memory == 0); + } + + // stress test + { + BlockAllocator::Config config = {0}; + config.minimum_block_size = 1024; + + BlockAllocator::MemoryAllocators allocators = {system_allocator, block_allocator, region_allocator}; + BlockAllocator *instance = BlockAllocator::create(user_context, config, allocators); + + MemoryRequest request = {0}; + request.size = sizeof(int); + request.alignment = sizeof(int); + request.properties.visibility = MemoryVisibility::DefaultVisibility; + request.properties.caching = MemoryCaching::DefaultCaching; + request.properties.usage = MemoryUsage::DefaultUsage; + + static size_t test_allocations = 1000; + PointerTable pointers(user_context, test_allocations, system_allocator); + for (size_t n = 0; n < test_allocations; ++n) { + size_t count = n % 32; + count = count > 1 ? count : 1; + request.size = count * sizeof(int); + MemoryRegion *region = instance->reserve(user_context, request); + pointers.append(user_context, region); + } + + for (size_t n = 0; n < pointers.size(); ++n) { + MemoryRegion *region = static_cast(pointers[n]); + instance->reclaim(user_context, region); + } + halide_abort_if_false(user_context, allocated_region_memory == 0); + + pointers.destroy(user_context); + instance->destroy(user_context); + halide_abort_if_false(user_context, allocated_block_memory == 0); + + BlockAllocator::destroy(user_context, instance); + halide_abort_if_false(user_context, allocated_system_memory == 0); + } + + print(user_context) << "Success!\n"; + return 0; +} diff --git a/test/runtime/block_storage.cpp b/test/runtime/block_storage.cpp new file mode 100644 index 000000000000..ad7499f84378 --- /dev/null +++ b/test/runtime/block_storage.cpp @@ -0,0 +1,148 @@ +#include "common.h" + +#include "internal/block_storage.h" + +using namespace Halide::Runtime::Internal; + +struct TestStruct { + int8_t i8; + uint16_t ui16; + float f32; +}; + +template +T read_as(const BlockStorage &bs, size_t index) { + const T *ptr = static_cast(bs[index]); + return *ptr; +} + +int main(int argc, char **argv) { + void *user_context = (void *)1; + + // test class interface + { + BlockStorage::Config config = BlockStorage::default_config(); + config.entry_size = sizeof(int); + + BlockStorage bs(user_context, config); + bs.reserve(user_context, 256); + halide_abort_if_false(user_context, bs.size() == 0); + + int a1[4] = {12, 34, 56, 78}; + bs.append(user_context, &a1[0]); + halide_abort_if_false(user_context, bs.size() == 1); + halide_abort_if_false(user_context, read_as(bs, 0) == a1[0]); + + bs.append(user_context, &a1[1]); + halide_abort_if_false(user_context, bs.size() == 2); + halide_abort_if_false(user_context, read_as(bs, 1) == a1[1]); + + bs.insert(user_context, 1, &a1[2]); + halide_abort_if_false(user_context, bs.size() == 3); + halide_abort_if_false(user_context, read_as(bs, 0) == a1[0]); + halide_abort_if_false(user_context, read_as(bs, 1) == a1[2]); // inserted here + halide_abort_if_false(user_context, read_as(bs, 2) == a1[1]); + + bs.prepend(user_context, &a1[3]); + halide_abort_if_false(user_context, bs.size() == 4); + halide_abort_if_false(user_context, read_as(bs, 0) == a1[3]); + + int a2[] = {98, 76, 54, 32, 10}; + size_t a2_size = 5; + bs.fill(user_context, a2, a2_size); + halide_abort_if_false(user_context, bs.size() == a2_size); + halide_abort_if_false(user_context, read_as(bs, 0) == a2[0]); + halide_abort_if_false(user_context, read_as(bs, 1) == a2[1]); + halide_abort_if_false(user_context, read_as(bs, 2) == a2[2]); + halide_abort_if_false(user_context, read_as(bs, 3) == a2[3]); + halide_abort_if_false(user_context, read_as(bs, 4) == a2[4]); + + int a3[] = {77, 66, 55}; + size_t a3_size = 3; + bs.insert(user_context, 2, a3, a3_size); + halide_abort_if_false(user_context, bs.size() == (a2_size + a3_size)); + halide_abort_if_false(user_context, read_as(bs, 0) == a2[0]); + halide_abort_if_false(user_context, read_as(bs, 1) == a2[1]); + halide_abort_if_false(user_context, read_as(bs, 2) == a3[0]); // a3 inserted here + halide_abort_if_false(user_context, read_as(bs, 3) == a3[1]); + halide_abort_if_false(user_context, read_as(bs, 4) == a3[2]); + halide_abort_if_false(user_context, read_as(bs, 5) == a2[2]); // a2 resumes here + halide_abort_if_false(user_context, read_as(bs, 6) == a2[3]); + halide_abort_if_false(user_context, read_as(bs, 7) == a2[4]); + + bs.pop_front(user_context); + bs.pop_front(user_context); + + bs.pop_back(user_context); + bs.pop_back(user_context); + + halide_abort_if_false(user_context, bs.size() == (a2_size + a3_size - 4)); + halide_abort_if_false(user_context, read_as(bs, 0) == a3[0]); + halide_abort_if_false(user_context, read_as(bs, 1) == a3[1]); + halide_abort_if_false(user_context, read_as(bs, 2) == a3[2]); + halide_abort_if_false(user_context, read_as(bs, 3) == a2[2]); + + bs.clear(user_context); + halide_abort_if_false(user_context, bs.size() == 0); + } + + // test copy and equality + { + BlockStorage::Config config = BlockStorage::default_config(); + config.entry_size = sizeof(int); + + int a1[] = {98, 76, 54, 32, 10}; + size_t a1_size = 5; + + int a2[] = {77, 66, 55}; + size_t a2_size = 3; + + BlockStorage bs1(user_context, config); + bs1.fill(user_context, a1, a1_size); + + BlockStorage bs2(user_context, config); + bs2.fill(user_context, a2, a2_size); + + BlockStorage bs3(bs1); + + halide_abort_if_false(user_context, bs1.size() == (a1_size)); + halide_abort_if_false(user_context, bs2.size() == (a2_size)); + halide_abort_if_false(user_context, bs3.size() == bs1.size()); + + halide_abort_if_false(user_context, bs1 != bs2); + halide_abort_if_false(user_context, bs1 == bs3); + + bs2 = bs1; + halide_abort_if_false(user_context, bs1 == bs2); + } + + // test struct storage + { + BlockStorage::Config config = BlockStorage::default_config(); + config.entry_size = sizeof(TestStruct); + + BlockStorage bs(user_context, config); + halide_abort_if_false(user_context, bs.size() == 0); + + TestStruct s1 = {8, 16, 32.0f}; + bs.append(user_context, &s1); + halide_abort_if_false(user_context, bs.size() == 1); + + const TestStruct e1 = read_as(bs, 0); + halide_abort_if_false(user_context, e1.i8 == s1.i8); + halide_abort_if_false(user_context, e1.ui16 == s1.ui16); + halide_abort_if_false(user_context, e1.f32 == s1.f32); + + TestStruct s2 = {1, 2, 3.0f}; + bs.prepend(user_context, &s2); + halide_abort_if_false(user_context, bs.size() == 2); + + const TestStruct e2 = read_as(bs, 0); + halide_abort_if_false(user_context, e2.i8 == s2.i8); + halide_abort_if_false(user_context, e2.ui16 == s2.ui16); + halide_abort_if_false(user_context, e2.f32 == s2.f32); + } + + print(user_context) << "Success!\n"; + return 0; +} diff --git a/test/runtime/common.h b/test/runtime/common.h new file mode 100644 index 000000000000..d4158c67a743 --- /dev/null +++ b/test/runtime/common.h @@ -0,0 +1,80 @@ +#include +#include + +#include "HalideRuntime.h" +#include "msan_stubs.cpp" +#include "runtime_internal.h" +#include "to_string.cpp" + +extern "C" { + +extern int printf(const char *format, ...); + +void halide_print(void *user_context, const char *str) { + printf("%s", str); +} + +void halide_error(void *user_context, const char *msg) { + halide_print(user_context, msg); +} + +void halide_profiler_report(void *user_context) { +} + +void halide_profiler_reset() { +} + +} // extern "C" + +#include "printer.h" + +namespace { + +size_t allocated_system_memory = 0; + +void *align_up(void *ptr, size_t offset, size_t alignment) { + return (void *)(((((size_t)ptr + offset)) + (alignment - 1)) & ~(alignment - 1)); +} + +void *allocate_system(void *user_context, size_t bytes) { + constexpr size_t alignment = 128; + constexpr size_t header_size = 2 * sizeof(size_t); + size_t alloc_size = bytes + header_size + (alignment - 1); + void *raw_ptr = malloc(alloc_size); + if (raw_ptr == nullptr) { + return nullptr; + } + void *aligned_ptr = align_up(raw_ptr, header_size, alignment); + size_t aligned_offset = (size_t)((size_t)aligned_ptr - (size_t)raw_ptr); + *((size_t *)aligned_ptr - 1) = aligned_offset; + *((size_t *)aligned_ptr - 2) = alloc_size; + allocated_system_memory += alloc_size; + + debug(user_context) << "Test : allocate_system (" + << "ptr=" << (void *)(raw_ptr) << " " + << "aligned_ptr=" << (void *)(aligned_ptr) << " " + << "aligned_offset=" << int32_t(aligned_offset) << " " + << "alloc_size=" << int32_t(alloc_size) << " " + << "allocated_system_memory=" << int32_t(allocated_system_memory) << " " + << ") !\n"; + + return aligned_ptr; +} + +void deallocate_system(void *user_context, void *aligned_ptr) { + size_t aligned_offset = *((size_t *)aligned_ptr - 1); + size_t alloc_size = *((size_t *)aligned_ptr - 2); + void *raw_ptr = (void *)((uint8_t *)aligned_ptr - aligned_offset); + free(raw_ptr); + allocated_system_memory -= alloc_size; + + debug(user_context) << "Test : deallocate_system (" + << "ptr=" << (void *)(raw_ptr) << " " + << "aligned_ptr=" << (void *)(aligned_ptr) << " " + << "aligned_offset=" << int32_t(aligned_offset) << " " + << "alloc_size=" << int32_t(alloc_size) << " " + << "allocated_system_memory=" << int32_t(allocated_system_memory) << " " + << ") !\n"; +} + +} // anonymous namespace diff --git a/test/runtime/linked_list.cpp b/test/runtime/linked_list.cpp new file mode 100644 index 000000000000..807406ed4d20 --- /dev/null +++ b/test/runtime/linked_list.cpp @@ -0,0 +1,104 @@ +#include "common.h" + +#include "internal/linked_list.h" + +using namespace Halide::Runtime::Internal; + +struct TestStruct { + int8_t i8; + uint16_t ui16; + float f32; +}; + +template +T read_as(const LinkedList::EntryType *entry_ptr) { + const T *ptr = static_cast(entry_ptr->value); + return *ptr; +} + +int main(int argc, char **argv) { + void *user_context = (void *)1; + SystemMemoryAllocatorFns test_allocator = {allocate_system, deallocate_system}; + + // test class interface + { + LinkedList list(user_context, sizeof(int), 64, test_allocator); + halide_abort_if_false(user_context, list.size() == 0); + + const int i0 = 12; + list.append(user_context, &i0); // contents: 12 + halide_abort_if_false(user_context, list.size() == 1); + halide_abort_if_false(user_context, (list.front() != nullptr)); + halide_abort_if_false(user_context, (list.back() != nullptr)); + halide_abort_if_false(user_context, read_as(list.front()) == i0); + halide_abort_if_false(user_context, read_as(list.back()) == i0); + + const int i1 = 34; + list.append(user_context, &i1); // contents: 12, 34 + halide_abort_if_false(user_context, list.size() == 2); + halide_abort_if_false(user_context, read_as(list.back()) == i1); + + const int i2 = 56; + list.insert_before(user_context, list.back(), &i2); // contents: 12, 56, 34 + halide_abort_if_false(user_context, list.size() == 3); + halide_abort_if_false(user_context, read_as(list.back()) == i1); + + const int i3 = 78; + list.prepend(user_context, &i3); // contents: 78, 12, 56, 34 + halide_abort_if_false(user_context, list.size() == 4); + halide_abort_if_false(user_context, read_as(list.front()) == i3); + halide_abort_if_false(user_context, read_as(list.back()) == i1); + + list.pop_front(user_context); // contents: 12, 56, 34 + halide_abort_if_false(user_context, list.size() == 3); + halide_abort_if_false(user_context, read_as(list.front()) == i0); + halide_abort_if_false(user_context, read_as(list.back()) == i1); + + list.pop_back(user_context); // contents: 12, 56 + halide_abort_if_false(user_context, list.size() == 2); + halide_abort_if_false(user_context, read_as(list.front()) == i0); + halide_abort_if_false(user_context, read_as(list.back()) == i2); + + list.clear(user_context); + halide_abort_if_false(user_context, list.size() == 0); + + size_t count = 4 * 1024; + for (size_t n = 0; n < count; ++n) { + list.append(user_context, &n); + } + halide_abort_if_false(user_context, list.size() == count); + + list.destroy(user_context); + halide_abort_if_false(user_context, allocated_system_memory == 0); + } + + // test struct storage + { + LinkedList list(user_context, sizeof(TestStruct), 32, test_allocator); + halide_abort_if_false(user_context, list.size() == 0); + + TestStruct s1 = {8, 16, 32.0f}; + list.append(user_context, &s1); + halide_abort_if_false(user_context, list.size() == 1); + + const TestStruct e1 = read_as(list.front()); + halide_abort_if_false(user_context, e1.i8 == s1.i8); + halide_abort_if_false(user_context, e1.ui16 == s1.ui16); + halide_abort_if_false(user_context, e1.f32 == s1.f32); + + TestStruct s2 = {1, 2, 3.0f}; + list.prepend(user_context, &s2); + halide_abort_if_false(user_context, list.size() == 2); + + TestStruct e2 = read_as(list.front()); + halide_abort_if_false(user_context, e2.i8 == s2.i8); + halide_abort_if_false(user_context, e2.ui16 == s2.ui16); + halide_abort_if_false(user_context, e2.f32 == s2.f32); + + list.destroy(user_context); + halide_abort_if_false(user_context, allocated_system_memory == 0); + } + + print(user_context) << "Success!\n"; + return 0; +} diff --git a/test/runtime/memory_arena.cpp b/test/runtime/memory_arena.cpp new file mode 100644 index 000000000000..338d190cea39 --- /dev/null +++ b/test/runtime/memory_arena.cpp @@ -0,0 +1,96 @@ +#include "common.h" + +#include "internal/memory_arena.h" + +using namespace Halide::Runtime::Internal; + +struct TestStruct { + int8_t i8; + uint16_t ui16; + float f32; +}; + +int main(int argc, char **argv) { + void *user_context = (void *)1; + + // test class interface + { + SystemMemoryAllocatorFns test_allocator = {allocate_system, deallocate_system}; + + MemoryArena::Config config = {sizeof(int), 32, 0}; + MemoryArena arena(user_context, config, test_allocator); + void *p1 = arena.reserve(user_context); + halide_abort_if_false(user_context, allocated_system_memory >= (1 * sizeof(int))); + halide_abort_if_false(user_context, p1 != nullptr); + + void *p2 = arena.reserve(user_context, true); + halide_abort_if_false(user_context, allocated_system_memory >= (2 * sizeof(int))); + halide_abort_if_false(user_context, p2 != nullptr); + halide_abort_if_false(user_context, (*static_cast(p2)) == 0); + + arena.reclaim(user_context, p1); + arena.destroy(user_context); + + halide_abort_if_false(user_context, allocated_system_memory == 0); + } + + // test dyanmic construction + { + SystemMemoryAllocatorFns test_allocator = {allocate_system, deallocate_system}; + + MemoryArena::Config config = {sizeof(double), 32, 0}; + MemoryArena *arena = MemoryArena::create(user_context, config, test_allocator); + + size_t count = 4 * 1024; + void *pointers[count]; + for (size_t n = 0; n < count; ++n) { + pointers[n] = arena->reserve(user_context, true); + } + halide_abort_if_false(user_context, allocated_system_memory >= (count * sizeof(int))); + for (size_t n = 0; n < count; ++n) { + void *ptr = pointers[n]; + halide_abort_if_false(user_context, ptr != nullptr); + halide_abort_if_false(user_context, (*static_cast(ptr)) == 0.0); + } + arena->destroy(user_context); + + MemoryArena::destroy(user_context, arena); + halide_abort_if_false(user_context, allocated_system_memory == 0); + } + + // test struct allocations + { + SystemMemoryAllocatorFns test_allocator = {allocate_system, deallocate_system}; + MemoryArena::Config config = {sizeof(TestStruct), 32, 0}; + MemoryArena arena(user_context, config, test_allocator); + void *s1 = arena.reserve(user_context, true); + halide_abort_if_false(user_context, s1 != nullptr); + halide_abort_if_false(user_context, allocated_system_memory >= (1 * sizeof(int))); + halide_abort_if_false(user_context, ((TestStruct *)s1)->i8 == int8_t(0)); + halide_abort_if_false(user_context, ((TestStruct *)s1)->ui16 == uint16_t(0)); + halide_abort_if_false(user_context, ((TestStruct *)s1)->f32 == float(0)); + + arena.destroy(user_context); + + size_t count = 4 * 1024; + void *pointers[count]; + for (size_t n = 0; n < count; ++n) { + pointers[n] = arena.reserve(user_context, true); + } + + for (size_t n = 0; n < count; ++n) { + void *s1 = pointers[n]; + halide_abort_if_false(user_context, s1 != nullptr); + halide_abort_if_false(user_context, ((TestStruct *)s1)->i8 == int8_t(0)); + halide_abort_if_false(user_context, ((TestStruct *)s1)->ui16 == uint16_t(0)); + halide_abort_if_false(user_context, ((TestStruct *)s1)->f32 == float(0)); + } + + arena.destroy(user_context); + + halide_abort_if_false(user_context, allocated_system_memory == 0); + } + + print(user_context) << "Success!\n"; + return 0; +} diff --git a/test/runtime/string_storage.cpp b/test/runtime/string_storage.cpp new file mode 100644 index 000000000000..a557f9af2c19 --- /dev/null +++ b/test/runtime/string_storage.cpp @@ -0,0 +1,72 @@ +#include "common.h" + +#include "internal/string_storage.h" + +using namespace Halide::Runtime::Internal; + +int main(int argc, char **argv) { + void *user_context = (void *)1; + SystemMemoryAllocatorFns test_allocator = {allocate_system, deallocate_system}; + + // test class interface + { + StringStorage ss(user_context, 0, test_allocator); + halide_abort_if_false(user_context, ss.length() == 0); + + const char *ts1 = "Testing!"; + const size_t ts1_length = strlen(ts1); + ss.assign(user_context, ts1); + halide_abort_if_false(user_context, ss.length() == ts1_length); + halide_abort_if_false(user_context, ss.contains(ts1)); + + const char *ts2 = "More "; + const size_t ts2_length = strlen(ts2); + ss.prepend(user_context, ts2); + halide_abort_if_false(user_context, ss.length() == (ts1_length + ts2_length)); + halide_abort_if_false(user_context, ss.contains(ts2)); + halide_abort_if_false(user_context, ss.contains(ts1)); + + ss.append(user_context, '!'); + halide_abort_if_false(user_context, ss.length() == (ts1_length + ts2_length + 1)); + + ss.clear(user_context); + halide_abort_if_false(user_context, ss.length() == 0); + + ss.destroy(user_context); + halide_abort_if_false(user_context, allocated_system_memory == 0); + } + + // test copy and equality + { + const char *ts1 = "Test One!"; + const size_t ts1_length = strlen(ts1); + + const char *ts2 = "Test Two!"; + const size_t ts2_length = strlen(ts2); + + StringStorage ss1(user_context, 0, test_allocator); + ss1.assign(user_context, ts1, ts1_length); + + StringStorage ss2(user_context, 0, test_allocator); + ss2.assign(user_context, ts2, ts2_length); + + StringStorage ss3(ss1); + + halide_abort_if_false(user_context, ss1.length() == (ts1_length)); + halide_abort_if_false(user_context, ss2.length() == (ts2_length)); + halide_abort_if_false(user_context, ss3.length() == ss1.length()); + + halide_abort_if_false(user_context, ss1 != ss2); + halide_abort_if_false(user_context, ss1 == ss3); + + ss2 = ss1; + halide_abort_if_false(user_context, ss1 == ss2); + + ss1.destroy(user_context); + ss2.destroy(user_context); + ss3.destroy(user_context); + halide_abort_if_false(user_context, allocated_system_memory == 0); + } + print(user_context) << "Success!\n"; + return 0; +} diff --git a/test/runtime/string_table.cpp b/test/runtime/string_table.cpp new file mode 100644 index 000000000000..6fef995aa73b --- /dev/null +++ b/test/runtime/string_table.cpp @@ -0,0 +1,45 @@ +#include "common.h" + +#include "internal/string_table.h" + +using namespace Halide::Runtime::Internal; + +int main(int argc, char **argv) { + void *user_context = (void *)1; + SystemMemoryAllocatorFns test_allocator = {allocate_system, deallocate_system}; + + // test class interface + { + size_t data_size = 4; + const char *data[] = { + "one", "two", "three", "four"}; + + StringTable st1(user_context, 0, test_allocator); + halide_abort_if_false(user_context, st1.size() == 0); + + st1.fill(user_context, data, data_size); + halide_abort_if_false(user_context, st1.size() == data_size); + halide_abort_if_false(user_context, strncmp(st1[0], data[0], strlen(data[0])) == 0); + halide_abort_if_false(user_context, strncmp(st1[1], data[1], strlen(data[1])) == 0); + halide_abort_if_false(user_context, strncmp(st1[2], data[2], strlen(data[2])) == 0); + halide_abort_if_false(user_context, strncmp(st1[3], data[3], strlen(data[3])) == 0); + halide_abort_if_false(user_context, st1.contains(data[0])); + halide_abort_if_false(user_context, st1.contains(data[1])); + halide_abort_if_false(user_context, st1.contains(data[2])); + halide_abort_if_false(user_context, st1.contains(data[3])); + + st1.clear(user_context); + halide_abort_if_false(user_context, st1.size() == 0); + + size_t entry_count = st1.parse(user_context, "one:two:three:four", ":"); + halide_abort_if_false(user_context, entry_count == data_size); + halide_abort_if_false(user_context, st1.size() == data_size); + halide_abort_if_false(user_context, st1.contains(data[0])); + halide_abort_if_false(user_context, st1.contains(data[1])); + halide_abort_if_false(user_context, st1.contains(data[2])); + halide_abort_if_false(user_context, st1.contains(data[3])); + } + + print(user_context) << "Success!\n"; + return 0; +} diff --git a/tools/GenGen.cpp b/tools/GenGen.cpp index 5ff83e4796c5..c6192c16179e 100644 --- a/tools/GenGen.cpp +++ b/tools/GenGen.cpp @@ -1,5 +1,5 @@ #include "Halide.h" int main(int argc, char **argv) { - return Halide::Internal::generate_filter_main(argc, argv, std::cerr); + return Halide::Internal::generate_filter_main(argc, argv); } diff --git a/tools/halide_image_io.h b/tools/halide_image_io.h index 7cfa04a860b6..a8cb5f7d293a 100644 --- a/tools/halide_image_io.h +++ b/tools/halide_image_io.h @@ -2347,7 +2347,7 @@ class load_and_convert_image { // a runtime error will occur. template void save_image(ImageType &im, const std::string &filename) { - auto im_d = im.template as(); + auto im_d = im.template as(); (void)save(im_d, filename); } diff --git a/tools/mex_halide.m b/tools/mex_halide.m deleted file mode 100644 index 4c6f342e2729..000000000000 --- a/tools/mex_halide.m +++ /dev/null @@ -1,114 +0,0 @@ -function mex_halide( generator_filename, varargin ) -%mex_halide - Create a mex library from a Halide generator source -%file. -% -% generator_filename identifies a C++ source file containing a generator. -% The remaining arguments are a list of name-value pairs of the form -% 'generator_param=value' used to assign the generator params, or -% additional flags: -% -e : Which outputs to emit from the -% generator, multiply outputs can be specified with a comma -% delimited list. -% -c : Which C++ compiler to use to build the -% generator. Default is 'c++'. -% -g : Which generator to build. If only one generator -% is registered, it will be used by default. -% -% If a target is specified by a generator param with target=..., the -% 'matlab' feature flag must be present. -% -% This script uses two environment variables that can optionally be -% set or changed: -% - HALIDE_DISTRIB_PATH: The path to the distrib directory of Halide. If -% unspecified, this defaults to '../../distrib' relative to mex_halide.m. -% - HALIDE_CXX: The C++ compiler to use to build generators. The -% default is 'c++'. - - gengen_cpp = ['#include "Halide.h"', sprintf('\n'), ... - 'int main(int argc, char **argv) {', ... - ' return Halide::Internal::generate_filter_main(argc, argv, std::cerr);', ... - '}']; - - % Make a temporary directory for our intermediates. - temp = fullfile(tempdir, 'mex_halide'); - if ~exist(temp, 'dir') - mkdir(temp); - end - - % Write the generator main program to a temporary file. - gengen_filename = fullfile(temp, 'GenGen.cpp'); - gengen_file = fopen(gengen_filename, 'w'); - fprintf(gengen_file, '%s', gengen_cpp); - fclose(gengen_file); - - % Build the filenames of the intermediate object we will generate. - [path, filename] = fileparts(generator_filename); - object_file = fullfile(temp, [filename, '.o']); - function_name = filename; - - % Concatenate the generator args into a single string. - generator_args = strjoin(varargin); - target = 'host-matlab'; - - if isempty(getenv('HALIDE_DISTRIB_PATH')) - % If the user has not set the halide path, get the path of - % this file (presumably in $HALIDE_DISTRIB_PATH/tools/) and use - % that. - [path, ~] = fileparts(mfilename('fullpath')); - halide_distrib_path = fullfile(path, '..'); - setenv('HALIDE_DISTRIB_PATH', halide_distrib_path); - end - halide_distrib_path = getenv('HALIDE_DISTRIB_PATH'); - - if ismac - libhalide = fullfile(halide_distrib_path, 'lib', 'libHalide.dylib'); - else - libhalide = fullfile(halide_distrib_path, 'lib', 'libHalide.so'); - end - halide_include = fullfile(halide_distrib_path, 'include'); - - if isempty(getenv('HALIDE_CXX')) - % If the user has not set a compiler for Halide, use c++. - setenv('HALIDE_CXX', 'c++'); - end - halide_cxx = getenv('HALIDE_CXX'); - - ld_library_path = fullfile(halide_distrib_path, 'lib'); - - % Build the command to build the generator. - gen_bin = fullfile(temp, [function_name, '.generator']); - build_generator = ... - [halide_cxx, ... - ' -g -Wall -std=c++17 -fno-rtti -I', halide_include, ' ', ... - gengen_filename, ' ', ... - generator_filename, ' ', ... - libhalide, ' ', ... - ' -lz -lpthread -ldl ', ... - '-o ', gen_bin]; - status = system(build_generator); - if status ~= 0 - error('mex_halide:build_failed', 'Generator build failed.'); - return; - end - - % Run the generator to build the object file. - build_object = ... - ['LD_LIBRARY_PATH=', ld_library_path, ' ', ... - 'DYLD_LIBRARY_PATH=', ld_library_path, ' ', ... - gen_bin, ' ', ... - '-f ', function_name, ' ', ... - '-o ', temp, ' ', ... - '-e o,h ', ... - 'target=', target, ' ', ... - generator_args]; - status = system(build_object); - if status ~= 0 - error('mex_halide:build_failed', ['Generator failed to build ' ... - 'pipeline.']); - return; - end - - % Run mex on the resulting object file. - mex(object_file, '-ldl'); - -end diff --git a/tutorial/CMakeLists.txt b/tutorial/CMakeLists.txt index 7d3226d957bb..7c1b1f656132 100644 --- a/tutorial/CMakeLists.txt +++ b/tutorial/CMakeLists.txt @@ -4,10 +4,10 @@ configure_file(images/rgb.png images/rgb.png COPYONLY) function(add_tutorial source_file) set(options WITH_IMAGE_IO WITH_OPENMP) set(oneValueArgs) - set(multiValueArgs SRCS) + set(multiValueArgs SRCS GROUPS) cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - get_filename_component(name "${source_file}" NAME_WE) + cmake_path(GET source_file STEM name) add_executable("${name}" "${source_file}") target_link_libraries("${name}" PRIVATE Halide::Halide Halide::Tools) @@ -17,7 +17,7 @@ function(add_tutorial source_file) set_tests_properties(tutorial_${name} PROPERTIES ENVIRONMENT "HL_TARGET=${Halide_TARGET};HL_JIT_TARGET=${Halide_TARGET}" - LABELS tutorial) + LABELS "tutorial;${args_GROUPS}") if (args_WITH_IMAGE_IO) target_link_libraries(${name} PRIVATE Halide::ImageIO) @@ -40,12 +40,12 @@ endfunction() add_tutorial(lesson_01_basics.cpp) add_tutorial(lesson_02_input_image.cpp WITH_IMAGE_IO) add_tutorial(lesson_03_debugging_1.cpp) -add_tutorial(lesson_04_debugging_2.cpp) -add_tutorial(lesson_05_scheduling_1.cpp) +add_tutorial(lesson_04_debugging_2.cpp GROUPS multithreaded) +add_tutorial(lesson_05_scheduling_1.cpp GROUPS multithreaded) add_tutorial(lesson_06_realizing_over_shifted_domains.cpp) add_tutorial(lesson_07_multi_stage_pipelines.cpp WITH_IMAGE_IO) -add_tutorial(lesson_08_scheduling_2.cpp WITH_IMAGE_IO WITH_OPENMP) -add_tutorial(lesson_09_update_definitions.cpp WITH_IMAGE_IO WITH_OPENMP) +add_tutorial(lesson_08_scheduling_2.cpp WITH_IMAGE_IO WITH_OPENMP GROUPS multithreaded) +add_tutorial(lesson_09_update_definitions.cpp WITH_IMAGE_IO WITH_OPENMP GROUPS multithreaded) if (TARGET_NVPTX) if (TARGET_WEBASSEMBLY AND Halide_TARGET MATCHES "wasm") @@ -56,13 +56,10 @@ if (TARGET_NVPTX) # so we can build the final executable. add_tutorial(lesson_10_aot_compilation_generate.cpp) - # LLVM may leak memory during Halide compilation. If projects are built with address sanitizer enabled, - # this may cause generators to fail, making it hard to use Halide and address sanitizer at the same time. - # To work around this, we execute generators with an environment setting to disable leak checking. set(FILTER_LIB "lesson_10_halide${CMAKE_STATIC_LIBRARY_SUFFIX}") add_custom_command(OUTPUT lesson_10_halide.h "${FILTER_LIB}" DEPENDS lesson_10_aot_compilation_generate - COMMAND ${CMAKE_COMMAND} -E env "ASAN_OPTIONS=detect_leaks=0" $ + COMMAND lesson_10_aot_compilation_generate VERBATIM) add_custom_target(exec_lesson_10_aot_compilation_generate DEPENDS lesson_10_halide.h "${FILTER_LIB}") @@ -85,12 +82,12 @@ if (TARGET_NVPTX) target_include_directories(lesson_10_aot_compilation_run PRIVATE "${CMAKE_CURRENT_BINARY_DIR}") add_test(NAME tutorial_lesson_10_aot_compilation_run COMMAND lesson_10_aot_compilation_run) - set_tests_properties(tutorial_lesson_10_aot_compilation_run PROPERTIES LABELS tutorial) + set_tests_properties(tutorial_lesson_10_aot_compilation_run PROPERTIES LABELS "tutorial;multithreaded") endif () endif () add_tutorial(lesson_11_cross_compilation.cpp) -add_tutorial(lesson_12_using_the_gpu.cpp WITH_IMAGE_IO) +add_tutorial(lesson_12_using_the_gpu.cpp WITH_IMAGE_IO GROUPS multithreaded) add_tutorial(lesson_13_tuples.cpp) add_tutorial(lesson_14_types.cpp) @@ -183,7 +180,7 @@ endif () # Lessons 17 - 20 add_tutorial(lesson_17_predicated_rdom.cpp) -add_tutorial(lesson_18_parallel_associative_reductions.cpp) +add_tutorial(lesson_18_parallel_associative_reductions.cpp GROUPS multithreaded) add_tutorial(lesson_19_wrapper_funcs.cpp) add_tutorial(lesson_20_cloning_funcs.cpp) @@ -194,16 +191,19 @@ if (TARGET Halide::Mullapudi2016) add_halide_library(auto_schedule_false FROM lesson_21_auto_scheduler_generate TARGETS cmake - GENERATOR auto_schedule_gen PARAMS auto_schedule=false) + GENERATOR auto_schedule_gen) add_halide_library(auto_schedule_true FROM lesson_21_auto_scheduler_generate TARGETS cmake AUTOSCHEDULER Halide::Mullapudi2016 - GENERATOR auto_schedule_gen PARAMS machine_params=32,16777216,40) + GENERATOR auto_schedule_gen + PARAMS autoscheduler.parallelism=32 + autoscheduler.last_level_cache_size=16777216 + autoscheduler.balance=40) add_executable(lesson_21_auto_scheduler_run lesson_21_auto_scheduler_run.cpp) target_link_libraries(lesson_21_auto_scheduler_run PRIVATE auto_schedule_false auto_schedule_true Halide::Tools) add_test(NAME tutorial_lesson_21_auto_scheduler_run COMMAND lesson_21_auto_scheduler_run) - set_tests_properties(tutorial_lesson_21_auto_scheduler_run PROPERTIES LABELS tutorial) + set_tests_properties(tutorial_lesson_21_auto_scheduler_run PROPERTIES LABELS "tutorial;multithreaded") endif () diff --git a/tutorial/lesson_14_types.cpp b/tutorial/lesson_14_types.cpp index b43d13bf5e9d..571b03fa834c 100644 --- a/tutorial/lesson_14_types.cpp +++ b/tutorial/lesson_14_types.cpp @@ -79,12 +79,12 @@ int main(int argc, char **argv) { // You can also query any defined Func for the types it produces. Func f1; f1(x) = cast(x); - assert(f1.output_types()[0] == UInt(8)); + assert(f1.types()[0] == UInt(8)); Func f2; f2(x) = {x, sin(x)}; - assert(f2.output_types()[0] == Int(32) && - f2.output_types()[1] == Float(32)); + assert(f2.types()[0] == Int(32) && + f2.types()[1] == Float(32)); } // Type promotion rules. diff --git a/tutorial/lesson_15_generators.cpp b/tutorial/lesson_15_generators.cpp index 969158b4926b..eee00ed05e60 100644 --- a/tutorial/lesson_15_generators.cpp +++ b/tutorial/lesson_15_generators.cpp @@ -155,7 +155,7 @@ class MySecondGenerator : public Halide::Generator { if (rotation != Rotation::None) { rotated .compute_at(output, y) - .vectorize(x, natural_vector_size(rotated.output_types()[0])); + .vectorize(x, natural_vector_size(rotated.types()[0])); } } }; diff --git a/tutorial/lesson_21_auto_scheduler_generate.cpp b/tutorial/lesson_21_auto_scheduler_generate.cpp index 44a1bcac6aea..4258599e8d58 100644 --- a/tutorial/lesson_21_auto_scheduler_generate.cpp +++ b/tutorial/lesson_21_auto_scheduler_generate.cpp @@ -2,7 +2,7 @@ // So far we have written Halide schedules by hand, but it is also possible to // ask Halide to suggest a reasonable schedule. We call this auto-scheduling. -// This lesson demonstrates how to use the auto-scheduler to generate a +// This lesson demonstrates how to use the autoscheduler to generate a // copy-pasteable CPU schedule that can be subsequently improved upon. // On linux or os x, you can compile and run it like so: @@ -11,7 +11,7 @@ // export LD_LIBRARY_PATH= # For linux // export DYLD_LIBRARY_PATH= # For OS X // ./lesson_21_generate -o . -g auto_schedule_gen -f auto_schedule_false -e static_library,h,schedule target=host auto_schedule=false -// ./lesson_21_generate -o . -g auto_schedule_gen -f auto_schedule_true -e static_library,h,schedule -p -S Mullapudi2016 target=host auto_schedule=true machine_params=32,16777216,40 +// ./lesson_21_generate -o . -g auto_schedule_gen -f auto_schedule_true -e static_library,h,schedule -p -S Mullapudi2016 target=host autoscheduler=Mullapudi2016 autoscheduler.parallelism=32 autoscheduler.last_level_cache_size=16777216 autoscheduler.balance=40 // g++ lesson_21_auto_scheduler_run.cpp -std=c++17 -I -I auto_schedule_false.a auto_schedule_true.a -ldl -lpthread -o lesson_21_run // ./lesson_21_run @@ -69,8 +69,8 @@ class AutoScheduled : public Halide::Generator { } void schedule() { - if (auto_schedule) { - // The auto-scheduler requires estimates on all the input/output + if (using_autoscheduler()) { + // The autoscheduler requires estimates on all the input/output // sizes and parameter values in order to compare different // alternatives and decide on a good schedule. @@ -95,31 +95,33 @@ class AutoScheduled : public Halide::Generator { // schedule will be. // To auto-schedule the pipeline, we don't have to do anything else: - // every Generator implicitly has a GeneratorParam named "auto_schedule"; - // if this is set to true, Halide will call auto_schedule() on all of - // our pipeline's outputs automatically. - - // Every Generator also implicitly has a GeneratorParams named "machine_params", - // which allows you to specify characteristics of the machine architecture - // for the auto-scheduler; it's generally specified in your Makefile. + // every Generator implicitly has a GeneratorParam named "auto_scheduler.name"; + // if this is set to the name of the Autoscheduler we want to use, Halide will + // apply it to all of our pipeline's outputs automatically. + + // Every Generator also implicitly has additional, optional GeneratorParams that are + // dependent on the specific Autoscheduler select, which allows you to specify + // characteristics of the machine architecture + // for the autoscheduler; it's generally specified in your Makefile. // If none is specified, the default machine parameters for a generic CPU - // architecture will be used by the auto-scheduler. + // architecture will be used by the autoscheduler. - // Let's see some arbitrary but plausible values for the machine parameters. + // Let's see some arbitrary but plausible values for the machine parameters + // for the Mullapudi2016 Autoscheduler: // - // const int kParallelism = 32; - // const int kLastLevelCacheSize = 16 * 1024 * 1024; - // const int kBalance = 40; - // MachineParams machine_params(kParallelism, kLastLevelCacheSize, kBalance); + // autoscheduler=Mullapudi2016 + // autoscheduler.parallelism=32 + // autoscheduler.last_level_cache_size=16777216 + // autoscheduler.balance=40 // - // The arguments to MachineParams are the maximum level of parallelism - // available, the size of the last-level cache (in KB), and the ratio + // These are the maximum level of parallelism + // available, the size of the last-level cache (in bytes), and the ratio // between the cost of a miss at the last level cache and the cost // of arithmetic on the target architecture, in that order. - // Note that when using the auto-scheduler, no schedule should have - // been applied to the pipeline; otherwise, the auto-scheduler will - // throw an error. The current auto-scheduler cannot handle a + // Note that when using the autoscheduler, no schedule should have + // been applied to the pipeline; otherwise, the autoscheduler will + // throw an error. The current autoscheduler cannot handle a // partially-scheduled pipeline. // If HL_DEBUG_CODEGEN is set to 3 or greater, the schedule will be dumped @@ -131,12 +133,12 @@ class AutoScheduled : public Halide::Generator { // Halide C++ source, which is readily copy-pasteable back into // this very same source file with few modifications. Programmers // can use this as a starting schedule and iteratively improve the - // schedule. Note that the current auto-scheduler is only able to + // schedule. Note that the current autoscheduler is only able to // generate CPU schedules and only does tiling, simple vectorization // and parallelization. It doesn't deal with line buffering, storage // reordering, or factoring reductions. - // At the time of writing, the auto-scheduler will produce the + // At the time of writing, the autoscheduler will produce the // following schedule for the estimates and machine parameters // declared above when run on this pipeline: // @@ -211,7 +213,7 @@ class AutoScheduled : public Halide::Generator { } else { // This is where you would declare the schedule you have written by - // hand or paste the schedule generated by the auto-scheduler. + // hand or paste the schedule generated by the autoscheduler. // We will use a naive schedule here to compare the performance of // the autoschedule with a basic schedule. gray.compute_root(); diff --git a/util/HalideTraceUtils.cpp b/util/HalideTraceUtils.cpp index 6311ff6c2734..7efc3c9ebaef 100644 --- a/util/HalideTraceUtils.cpp +++ b/util/HalideTraceUtils.cpp @@ -40,7 +40,7 @@ bool Packet::read(void *d, size_t size, FILE *fdesc) { perror("Failed during read"); exit(-1); } - return false; //EOF + return false; // EOF } return true;