
[torchbench] timm_efficientdet inference fails to run. #6899

@ysiraichi

Description


🐛 Bug

timm_efficientdet inference fails to run in both the Dynamo and non-Dynamo configurations. See the error below:

Traceback (most recent call last):
  File "xla/benchmarks/experiment_runner.py", line 945, in <module>
    main()
  File "xla/benchmarks/experiment_runner.py", line 941, in main
    runner.run()
  File "xla/benchmarks/experiment_runner.py", line 61, in run
    self.run_single_config()
  File "xla/benchmarks/experiment_runner.py", line 256, in run_single_config
    metrics, last_output = self.run_once_and_gather_metrics(
  File "xla/benchmarks/experiment_runner.py", line 345, in run_once_and_gather_metrics
    output, _ = loop(iter_fn=self._default_iter_fn)
  File "xla/benchmarks/experiment_runner.py", line 302, in loop
    output, timing, trace = iter_fn(benchmark_experiment, benchmark_model,
  File "xla/benchmarks/experiment_runner.py", line 218, in _default_iter_fn
    output = benchmark_model.model_iter_fn(
  File "xla/benchmarks/benchmark_model.py", line 170, in eval
    pred = self.module(*inputs)
  File "torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lib/python3.8/site-packages/effdet/bench.py", line 110, in forward
    return _batch_detection(
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: torch_xla/csrc/tensor_impl.cpp:138 : Check failed: !has_symbolic_sizes_strides_
*** Begin stack trace ***
    tsl::CurrentStackTrace[abi:cxx11]()
    torch_xla::XLATensorImpl::sizes_custom() const
    at::FunctionalTensorWrapper::sizes_custom() const
    c10::TensorType::create(at::Tensor const&)
    torch::jit::tensorTypeInCurrentExecutionContext(at::Tensor const&)
    _PyObject_MakeTpCall
    PyVectorcall_Call
    _PyObject_MakeTpCall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    PyVectorcall_Call
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    PyVectorcall_Call
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    _PyObject_FastCallDict
    _PyObject_Call_Prepend
    PyObject_Call
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyEval_EvalFrameDefault
    _PyEval_EvalFrameDefault
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    PyEval_EvalCodeEx
    PyEval_EvalCode
    PyRun_SimpleFileExFlags
    Py_RunMain
    Py_BytesMain
    __libc_start_main
    _start
*** End stack trace ***
Cannot call sizes_custom() on an XLA tensor with symbolic sizes/strides
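For context, the failing frame is effdet's _batch_detection, which runs under the TorchScript interpreter; the check fires when the interpreter (via torch::jit::tensorTypeInCurrentExecutionContext and c10::TensorType::create) asks an XLA tensor carrying symbolic sizes/strides for its concrete sizes. Below is a minimal sketch of that calling pattern only. The scripted helper and tensor shapes are illustrative assumptions, not code from effdet or the benchmark harness, and by itself the sketch will not reproduce the crash unless the argument tensor actually ends up with symbolic sizes/strides.

# Illustrative sketch (assumption; not code from effdet or the benchmark harness):
# a torch.jit.script'd function receiving an XLA tensor, which is the shape of
# the failing call. The reported crash additionally requires the tensor to
# carry symbolic sizes/strides, at which point the TorchScript interpreter's
# call to sizes_custom() trips the failed check.
import torch
import torch_xla.core.xla_model as xm


@torch.jit.script
def scripted_helper(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for effdet's scripted _batch_detection.
    return torch.clamp(x, min=0.0)


device = xm.xla_device()
x = torch.randn(16, 4, device=device)

# With plain static shapes this runs fine; with symbolic sizes/strides on `x`,
# building the tensor's static type inside the interpreter would hit
# "Cannot call sizes_custom() on an XLA tensor with symbolic sizes/strides".
out = scripted_helper(x)
print(out.shape)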

Affected Configurations

  • Inference+Dynamo
  • Inference+NonDynamo

Environment

  • Reproducible on XLA backend: CUDA
  • torch_xla version: 5c48be1

cc @miladm @JackCaoG @vanbasten23 @cota @golechwierowicz @frgossen @zpcore
