[Bug] Segfault when applying Parallel during TIR schedule rewriting #18424

@Cookiee235

Description

When running meta_schedule.tune_tir on a valid TIR module involving multi-dimensional access patterns, TVM crashes during the schedule rewriting phase (the RewriteParallelVectorizeUnroll postprocessor): applying the Parallel primitive in ConcreteScheduleNode::Parallel raises a ScheduleError that is not rendered, and the tuning run aborts with a RuntimeError from parallel_for_dynamic.

Actual behavior

Traceback (most recent call last):
  File "/share_container/LLMFuzz/TirFuzz/bugs/10-24_20-21/topi.gather_0_M1.py", line 32, in <module>
    database = ms.tir_integration.tune_tir(mod=tir_mod, target='llvm --num-cores=32', work_dir='./tune_tmp', max_trials_global=1, num_trials_per_iter=1)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "test", line 146, in tune_tir
    return tune_tasks(
           ^^^^^^^^^^^
  File "/software/tvm-latest/python/tvm/meta_schedule/tune.py", line 122, in tune_tasks
    task_scheduler.tune(
  File "/software/tvm-latest/python/tvm/meta_schedule/task_scheduler/task_scheduler.py", line 132, in tune
    _ffi_api.TaskSchedulerTune(  # type: ignore # pylint: disable=no-member
  File "python/tvm_ffi/cython/function.pxi", line 758, in core.Function.__call__
  File "<unknown>", line 0, in tvm::meta_schedule::GradientBasedNode::Tune(tvm::ffi::Array<tvm::meta_schedule::TuneContext, void>, tvm::ffi::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::ffi::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::ffi::Optional<tvm::meta_schedule::Database, void>, tvm::ffi::Optional<tvm::meta_schedule::CostModel, void>)
  File "<unknown>", line 0, in tvm::meta_schedule::TaskSchedulerNode::Tune(tvm::ffi::Array<tvm::meta_schedule::TuneContext, void>, tvm::ffi::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::ffi::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::ffi::Optional<tvm::meta_schedule::Database, void>, tvm::ffi::Optional<tvm::meta_schedule::CostModel, void>)
  File "<unknown>", line 0, in tvm::meta_schedule::EvolutionarySearchNode::GenerateMeasureCandidates()
  File "<unknown>", line 0, in tvm::meta_schedule::EvolutionarySearchNode::State::GenerateMeasureCandidates()
  File "<unknown>", line 0, in tvm::meta_schedule::EvolutionarySearchNode::State::SampleInitPopulation(int)
  File "<unknown>", line 0, in tvm::support::parallel_for_dynamic(int, int, int, std::function<void (int, int)> const&) [clone .cold]
  File "<unknown>", line 0, in tvm::runtime::detail::LogFatal::~LogFatal() [clone .constprop.0]
  File "<unknown>", line 0, in tvm::runtime::detail::LogFatal::Entry::Finalize()
RuntimeError: parallel_for_dynamic error with Traceback (most recent call last):
  File "<unknown>", line 0, in tvm::meta_schedule::GradientBasedNode::Tune(tvm::ffi::Array<tvm::meta_schedule::TuneContext, void>, tvm::ffi::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::ffi::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::ffi::Optional<tvm::meta_schedule::Database, void>, tvm::ffi::Optional<tvm::meta_schedule::CostModel, void>)
  File "<unknown>", line 0, in tvm::meta_schedule::TaskSchedulerNode::Tune(tvm::ffi::Array<tvm::meta_schedule::TuneContext, void>, tvm::ffi::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::ffi::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::ffi::Optional<tvm::meta_schedule::Database, void>, tvm::ffi::Optional<tvm::meta_schedule::CostModel, void>)
  File "<unknown>", line 0, in tvm::meta_schedule::EvolutionarySearchNode::GenerateMeasureCandidates()
  File "<unknown>", line 0, in tvm::meta_schedule::EvolutionarySearchNode::State::GenerateMeasureCandidates()
  File "<unknown>", line 0, in tvm::meta_schedule::EvolutionarySearchNode::State::SampleInitPopulation(int)
  File "<unknown>", line 0, in tvm::support::parallel_for_dynamic(int, int, int, std::function<void (int, int)> const&)
  File "<unknown>", line 0, in tvm::meta_schedule::EvolutionarySearchNode::State::SampleInitPopulation(int)::{lambda(int, int)#1}::operator()(int, int) const
  File "<unknown>", line 0, in tvm::meta_schedule::ThreadedTraceApply::Apply(tvm::IRModule const&, tvm::tir::Trace const&, long*)
  File "<unknown>", line 0, in tvm::meta_schedule::RewriteParallelVectorizeUnrollNode::Apply(tvm::tir::Schedule const&)
  File "<unknown>", line 0, in tvm::tir::RewriteFuseSplitParallelVectorize(tvm::tir::Schedule const&, tvm::ffi::Array<tvm::tir::LoopRV, void>*, int)
  File "<unknown>", line 0, in tvm::tir::TracedScheduleNode::Parallel(tvm::tir::LoopRV const&)
  File "/software/tvm-latest/src/tir/schedule/concrete_schedule.cc", line 630, in virtual void tvm::tir::ConcreteScheduleNode::Parallel(const tvm::tir::LoopRV&)
ScheduleError: (not rendered)

Environment

tvm: 0.23.dev0

Steps to reproduce

import tvm
from tvm import te, topi, tir
from tvm import meta_schedule as ms

tir_str = """# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(data: T.Buffer((4, 6, 8), "float32"), indices: T.Buffer((2, 6, 8), "int32"), T_gather: T.Buffer((2, 6, 8), "float32")):
        T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}), "tir.noalias": True})
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(2, 6, 8):
            with T.block("T_gather"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(data[indices[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2], indices[v_ax0, v_ax1, v_ax2])
                T.writes(T_gather[v_ax0, v_ax1, v_ax2])
                T_gather[v_ax0, v_ax1, v_ax2] = data[indices[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2]

        # Additional multi-dimensional access to trigger StorageFlatten
        for ax0, ax1, ax2 in T.grid(4, 6, 8):
            with T.block("additional_access"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(data[v_ax0, v_ax1, v_ax2])
                T.writes(data[v_ax0, v_ax1, v_ax2])
                data[v_ax0, v_ax1, v_ax2] = data[v_ax0, v_ax1, v_ax2] + T.float32(1.0)
"""

tir_mod = tvm.script.from_source(tir_str)
tir_mod.show()
database = ms.tir_integration.tune_tir(mod=tir_mod, target='llvm --num-cores=32', work_dir='./tune_tmp', max_trials_global=1, num_trials_per_iter=1)

Triage

  • needs-triage
  • meta-tune
