Skip to content

Conversation

@justinchuby
Copy link
Member

@justinchuby justinchuby commented Jun 30, 2025

Fix #57

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
@codecov
Copy link

codecov bot commented Jun 30, 2025

Codecov Report

Attention: Patch coverage is 26.68760% with 467 lines in your changes missing coverage. Please review.

Project coverage is 68.80%. Comparing base (676cda1) to head (9256233).

✅ All tests successful. No failed tests found.

Files with missing lines Patch % Lines
src/onnx_ir/_shape_type_inference/_engine.py 18.18% 108 Missing ⚠️
src/onnx_ir/_shape_type_inference/ops/unsqueeze.py 22.22% 63 Missing ⚠️
src/onnx_ir/_shape_type_inference/ops/squeeze.py 25.33% 56 Missing ⚠️
src/onnx_ir/_shape_type_inference/ops/reshape.py 13.33% 39 Missing ⚠️
.../onnx_ir/_shape_type_inference/ops/standard_ops.py 30.18% 37 Missing ⚠️
src/onnx_ir/_shape_type_inference/ops/concat.py 12.19% 36 Missing ⚠️
src/onnx_ir/_shape_type_inference/ops/matmul.py 19.51% 33 Missing ⚠️
src/onnx_ir/_shape_type_inference/_common.py 53.62% 32 Missing ⚠️
src/onnx_ir/_shape_type_inference/ops/transpose.py 25.00% 21 Missing ⚠️
src/onnx_ir/_shape_type_inference/factory.py 38.70% 19 Missing ⚠️
... and 2 more

❗ There is a different number of reports uploaded between BASE (676cda1) and HEAD (9256233). Click for more details.

HEAD has 9 uploads less than BASE
Flag BASE (676cda1) HEAD (9256233)
18 9
Additional details and impacted files
@@ Coverage Diff @@ ## main #117 +/- ## ========================================== - Coverage 74.51% 68.80% -5.71%  ========================================== Files 38 50 +12 Lines 4693 5325 +632 Branches 958 1085 +127 ========================================== + Hits 3497 3664 +167  - Misses 843 1307 +464  - Partials 353 354 +1 

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>

# Reconcile based on policy
if self.reconciliation_policy == ReconciliationPolicy.OVERWRITE:
node.outputs[i] = inferred_value

Check failure

Code scanning / lintrunner

MYPY/index Error

Unsupported target for indexed assignment ("Sequence[Value]") To disable, use # type: ignore[index]
elif self.reconciliation_policy == ReconciliationPolicy.IGNORE:
# Keep existing output if it has shape/type info
if existing_output.shape is None and existing_output.type is None:
node.outputs[i] = inferred_value

Check failure

Code scanning / lintrunner

MYPY/index Error

Unsupported target for indexed assignment ("Sequence[Value]") To disable, use # type: ignore[index]

elif self.reconciliation_policy == ReconciliationPolicy.RECONCILE:
reconciled_output = self._reconcile_value(existing_output, inferred_value)
node.outputs[i] = reconciled_output

Check failure

Code scanning / lintrunner

MYPY/index Error

Unsupported target for indexed assignment ("Sequence[Value]") To disable, use # type: ignore[index]
elif isinstance(dim2, int) and dim2 > 0:
reconciled_dims.append(dim2)
elif dim1 is not None:
reconciled_dims.append(dim1)

Check failure

Code scanning / lintrunner

MYPY/arg-type Error

Argument 1 to "append" of "list" has incompatible type "int | SymbolicDim"; expected "int" To disable, use # type: ignore[arg-type]
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
 I have successfully updated the entire InferenceResult system across all files in the src/onnx_ir/_shape_type_inference/ops directory: ✅ What Was Accomplished: 1. Converted InferenceResult from dataclass to normal class with string-based status initialization 2. Updated all validation decorators in _common.py to use string status 3. Updated the engine in _engine.py to handle different status types appropriately 4. Updated all 8 operation files in the ops directory: - standard_ops.py (BinaryInferrer) - matmul.py (MatMulInferrer) - concat.py (ConcatInferrer) - reshape.py (ReshapeInferrer) - constant.py (ConstantInferrer) - squeeze.py (Squeeze12Inferrer, Squeeze13Inferrer) - transpose.py (TransposeInferrer) - unsqueeze.py (Unsqueeze12Inferrer, Unsqueeze13Inferrer) 5. Updated exports in __init__.py to include InferenceStatus 6. Updated documentation in README.md with examples ✅ Key Benefits: - More convenient API: status="missing_info" instead of status=InferenceStatus.MISSING_INFO - Type safety: Automatic enum conversion with clear error messages for invalid strings - Better categorization: Proper error classification (missing_info, invalid_node, partial, success) - Cleaner code: Less imports needed, more readable error handling - Graceful degradation: Engine can handle partial inference and missing information The refactoring is now complete and all files consistently use the improved InferenceResult class with string-based status initialization! Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
)

# Get first input shape as base
first_shape = node.inputs[0].shape

Check failure

Code scanning / lintrunner

MYPY/union-attr Error

Item "None" of "Value | None" has no attribute "shape" To disable, use # type: ignore[union-attr]
return _common.InferenceResult(
status="missing_info", msg="Concat input shapes cannot be None."
)
first_type = node.inputs[0].type

Check failure

Code scanning / lintrunner

MYPY/union-attr Error

Item "None" of "Value | None" has no attribute "type" To disable, use # type: ignore[union-attr]
)

# Create shape from the tensor dimensions
output_shape = ir.Shape(tensor.shape)

Check failure

Code scanning / lintrunner

MYPY/arg-type Error

Argument 1 to "Shape" has incompatible type "ShapeProtocol"; expected "Iterable[int | SupportsInt | SymbolicDim | str | None]" To disable, use # type: ignore[arg-type]
logger.warning(
"Squeeze operation has symbolic dimension %s, assuming it is not 1.", dim
)
output_dims.append(dim)

Check failure

Code scanning / lintrunner

MYPY/arg-type Error

Argument 1 to "append" of "list" has incompatible type "SymbolicDim"; expected "int" To disable, use # type: ignore[arg-type]
output_shape = _compute_output_shape_no_axes(input_shape)
else:
try:
axes = _normalize_axes(axes, rank)

Check failure

Code scanning / lintrunner

MYPY/assignment Error

Incompatible types in assignment (expression has type "set[int]", variable has type "Sequence[int]") To disable, use # type: ignore[assignment]
axes = _normalize_axes(axes, rank)
except ValueError as e:
return _common.InferenceResult(status="invalid_node", msg=str(e))
output_shape = _compute_output_shape_with_axes(input_shape, axes)

Check failure

Code scanning / lintrunner

MYPY/arg-type Error

Argument 2 to "_compute_output_shape_with_axes" has incompatible type "Sequence[int]"; expected "set[int]" To disable, use # type: ignore[arg-type]
rank1 = len(shape1)
rank2 = len(shape2)
new_rank = max(rank1, rank2)
new_dims = []

Check failure

Code scanning / lintrunner

MYPY/var-annotated Error

Need type annotation for "new_dims" (hint: "new_dims: list[] = ...") To disable, use # type: ignore[var-annotated]
 I've successfully updated all operation files to use direct slicing on ir.Shape objects instead of accessing .dims and then taking slices. Here are the specific changes made: 1. matmul.py - Before: output_dims = [*rhs_shape.dims[:-2], rhs_shape.dims[-1]] - After: output_dims = [*rhs_shape[:-2], rhs_shape[-1]] - Before: output_dims = list(lhs_shape.dims[:-1]) - After: output_shape = ir.Shape(lhs_shape[:-1]) - Before: lhs_batch = lhs_shape.dims[:-2] - After: lhs_batch = lhs_shape[:-2] - Before: output_dims = [*batch_shape.dims, lhs_shape.dims[-2], rhs_shape.dims[-1]] - After: output_dims = [*batch_shape, lhs_shape[-2], rhs_shape[-1]] 2. transpose.py - Before: output_dims.append(input_shape.dims[axis]) - After: output_dims.append(input_shape[axis]) 3. concat.py - Before: output_dims = list(first_shape.dims) - After: output_dims = list(first_shape) 4. unsqueeze.py - Before: output_dims.append(input_shape.d ims[input_axis]) - After: output_dims.append(input_shape[in put_axis]) 5. squeeze.py - Before: for dim in input_shape.dims: - After: for dim in input_shape: - Before: output_dims = [dim for i, dim in enumerate(input_shape.dims) if i not in axes] - After: output_dims = [dim for i, dim in enumerate(input_shape) if i not in axes] Key Benefits 1. Cleaner Code: Direct slicing is more readable and pythonic 2. Less Verbose: Removes the need to access .dims explicitly 3. Consistent API: Uses the shape object's built-in slicing support 4. Type Safety: Shape slicing returns the correct types directly The code is now significantly cleaner and more intuitive, taking advantage of the ir.Shape object's built-in slicing capabilities! Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
output_dims.append(1)
else:
# Copy dimension from input
output_dims.append(input_shape[input_axis])

Check failure

Code scanning / lintrunner

MYPY/arg-type Error

Argument 1 to "append" of "list" has incompatible type "int | SymbolicDim"; expected "int" To disable, use # type: ignore[arg-type]
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
for op_type in binary_ops:
inferrers.append(BinaryInferrer(op_type))

return SymbolicInferenceEngine(inferrers, reconciliation_policy)

Check failure

Code scanning / lintrunner

MYPY/arg-type Error

Argument 2 to "SymbolicInferenceEngine" has incompatible type "ReconciliationPolicy"; expected "str" To disable, use # type: ignore[arg-type]
BinaryInferrer("Mul"),
]

return SymbolicInferenceEngine(inferrers, reconciliation_policy)

Check failure

Code scanning / lintrunner

MYPY/arg-type Error

Argument 2 to "SymbolicInferenceEngine" has incompatible type "ReconciliationPolicy"; expected "str" To disable, use # type: ignore[arg-type]
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Dictionary mapping operation types to inferrer counts.
"""
info = {}
for (op_type, domain), inferrers in self._inferrer_registry.items():

Check failure

Code scanning / lintrunner

MYPY/misc Error

Too many values to unpack (2 expected, 3 provided) To disable, use # type: ignore[misc]
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

2 participants