20 changes: 11 additions & 9 deletions paddle/phi/infermeta/ternary.cc
@@ -1985,15 +1985,17 @@ void ScatterInferMeta(const MetaTensor& x,
"Input(Updates)'s shape is %d.",
ref_dims.size(),
updates_dims.size()));
PADDLE_ENFORCE_LE(
index_dims[0],
updates_dims[0],
common::errors::InvalidArgument(
"The first dimension size of Input(Index) should be no greater "
"than Input(Updates), but received first dimension size of "
"Input(Index) is %d, Input(Updates) is %d.",
index_dims[0],
updates_dims[0]));
if (index_dims[0] != -1 && updates_dims[0] != -1) {
PADDLE_ENFORCE_LE(
index_dims[0],
updates_dims[0],
common::errors::InvalidArgument(
"The first dimension size of Input(Index) should be no greater "
"than Input(Updates), but received first dimension size of "
"Input(Index) is %d, Input(Updates) is %d.",
index_dims[0],
updates_dims[0]));
}
} else {
PADDLE_ENFORCE_EQ(
(ref_dims.size() - 1 == updates_dims.size()),
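The guard above follows the usual dynamic-shape convention: a dimension of -1 is unknown until runtime, so the compile-time `index_dims[0] <= updates_dims[0]` comparison only makes sense when both extents are concrete. A minimal Python sketch of the same logic (hypothetical helper, not a Paddle API):

```python
def check_scatter_first_dim(index_dims, updates_dims):
    # -1 marks a dynamic dimension whose extent is only known at runtime,
    # so the compile-time comparison must be skipped in that case.
    if index_dims[0] != -1 and updates_dims[0] != -1:
        if index_dims[0] > updates_dims[0]:
            raise ValueError(
                "The first dimension size of Input(Index) should be no "
                f"greater than Input(Updates), but received {index_dims[0]} "
                f"and {updates_dims[0]}."
            )

check_scatter_first_dim([2, 4], [8, 4])   # static dims: check runs and passes
check_scatter_first_dim([-1, 4], [2, 4])  # dynamic first dim: check is skipped
```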
35 changes: 5 additions & 30 deletions python/paddle/jit/sot/infer_meta.py
@@ -38,41 +38,18 @@
from paddle.static import InputSpec
from paddle.utils import flatten, is_sequence

from .symbolic_shape import SymbolicInt
from .utils import (
Cache,
Singleton,
map_if_extend,
meta_str,
update_list_inplace,
)

DynamicSymbolT = TypeVar("DynamicSymbolT")
SOT_INFER_META_INNER_VAR = "___SOT_INFER_META_INNER_VAR"


class SymbolicValue(metaclass=Singleton):
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"

def get_static_type(self) -> type:
raise NotImplementedError("get_py_type is not implemented.")


class SymbolicBool(SymbolicValue):
def get_static_type(self) -> type[bool]:
return bool


class SymbolicInt(SymbolicValue):
def get_static_type(self) -> type[int]:
return int


class SymbolicFloat(SymbolicValue):
def get_static_type(self) -> type[float]:
return float


class DistInfo:
def __init__(self, mesh=None, dims_mapping=None, local_shape=None):
self.mesh = mesh
@@ -152,16 +129,14 @@ def shape_with_special_symbol(
]

def with_dynamic_axes(self, name: str, dynamic_axes: list[int]) -> MetaInfo:
# NOTE(SigureMo): Make sure to create a new shape list with dynamic axes.
# We will create a new shape list variable lazily in the future.
shape = [
SymbolicInt() if i in dynamic_axes else dim
SymbolicInt(dim) if i in dynamic_axes else dim
for i, dim in enumerate(self.shape)
]
# NOTE(SigureMo): Ensure the output meta.shape is the same list object as
# self.shape to avoid creating two different data proxies for tensor.shape.
# Otherwise a new SymbolicVariable would be created when a dim is dynamic.
self.shape = update_list_inplace(self.shape, shape)
return MetaInfo(
self.shape,
shape,
self.dtype,
self.stop_gradient,
self.name,
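Two things changed in `with_dynamic_axes`: `SymbolicInt` now records the concrete extent it replaces (the class moved out of this file; it is imported from `.symbolic_shape`), and `update_list_inplace` mutates `self.shape` instead of rebinding it, so any proxy already holding that list observes the update. A hedged sketch, assuming `SymbolicInt` simply stores its value and `update_list_inplace` amounts to slice assignment:

```python
class SymbolicInt:
    # Assumption: keeps the concrete dim so later constraints can refer to it.
    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return f"SymbolicInt({self.value})"

def update_list_inplace(dst, src):
    dst[:] = src  # mutate in place so existing references see the change
    return dst

shape = [16, 32, 8]
proxy = shape  # e.g. a data proxy for tensor.shape holding the same list
new_shape = [
    SymbolicInt(dim) if i in (0, 2) else dim for i, dim in enumerate(shape)
]
shape = update_list_inplace(shape, new_shape)
print(proxy)  # [SymbolicInt(16), 32, SymbolicInt(8)] -- same list object
```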
@@ -14,6 +14,8 @@

# This file stores the customized function that will be called by the dispatch mechanism.

from __future__ import annotations

from ...utils import BreakGraphError, BreakGraphReasonBase, FallbackError


52 changes: 35 additions & 17 deletions python/paddle/jit/sot/opcode_translator/executor/function_graph.py
@@ -39,6 +39,7 @@
from ...profiler import EventGuard, event_register
from ...symbolic.statement_ir import Reference, StatementIR, Symbol
from ...symbolic.symbolic_context import SymbolicTraceContext
from ...symbolic_shape import SYMBOLIC_BINARY_OPS, SYMBOLIC_UNARY_OPS
from ...utils import (
ENV_SOT_ALLOW_DYNAMIC_SHAPE,
ENV_SOT_ENABLE_GUARD_TREE,
@@ -527,46 +528,45 @@ def message_handler(*args, **kwargs):
**kwargs,
)

def call_tensor_method(
self, method_name: str, *args: VariableBase, **kwargs
def call_symbolic_api(
self,
op: Callable[..., Any],
*args: VariableBase,
**kwargs: VariableBase,
):
"""
call tensor method, start symbolic trace.

Args:
method_name: tensor method name
"""
assert op in SYMBOLIC_UNARY_OPS + SYMBOLIC_BINARY_OPS
log(3, f"call symbolic api : {op.__name__}", "\n")

def message_handler(*args, **kwargs):
return f"Call tensor_method error: Tensor.{method_name}, may be not a valid operator api?"
return f"Call operator error: {op.__name__}"

return inner_error_default_handler(self.symbolic_call, message_handler)(
InferMetaCache(),
self.sir_ctx.call_METHOD,
method_name,
False,
self.sir_ctx.call_API,
op,
True,
*args,
**kwargs,
)

def call_symbolic_method(
def call_tensor_method(
self, method_name: str, *args: VariableBase, **kwargs
):
"""
call symbolic method, start symbolic trace.
call tensor method, start symbolic trace.

Args:
method_name: symbolic method name
method_name: tensor method name
"""

def message_handler(*args, **kwargs):
return f"Call symbolic_method error: Symbolic.{method_name}, may be not a valid operator api?"
return f"Call tensor_method error: Tensor.{method_name}, may be not a valid operator api?"

return inner_error_default_handler(self.symbolic_call, message_handler)(
InferMetaCache(),
self.sir_ctx.call_METHOD,
method_name,
True,
False,
*args,
**kwargs,
)
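A design note on the new `call_symbolic_api` entry point: passing the operator callable (e.g. `operator.add`) instead of a magic-method-name string lets argument order alone encode the reverse case, so no `__radd__`-style bookkeeping is needed; this is what allows the dispatch table in variable_dispatch.py below to drop its `is_reverse` branch. A standalone sketch of the idea (names are illustrative):

```python
import operator

SYMBOLIC_BINARY_OPS = (operator.add, operator.sub, operator.mul)

def call_symbolic_api(op, *args):
    # The callable itself identifies the op; operand order is positional.
    assert op in SYMBOLIC_BINARY_OPS
    return op(*args)

print(call_symbolic_api(operator.sub, 5, 3))  # 2: symbolic - constant
print(call_symbolic_api(operator.sub, 3, 5))  # -2: reversed, no __rsub__ needed
```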
@@ -799,6 +799,24 @@ def try_infer_meta_fn(args, kwargs) -> Any:

return VariableFactory.from_value(outputs, self, tracker)

def add_alias(
self,
src: TensorVariable | SymbolicVariable,
dst: TensorVariable | SymbolicVariable,
):
"""
Add an alias like `dst = src`
"""
alias_fn = lambda x: x
alias_fn.__name__ = "__sir_alias__"
inputs_arg_pack = ([src], {})
self.sir_ctx.call_API(
alias_fn,
convert_to_symbol(inputs_arg_pack),
convert_to_symbol(dst),
[],
)

@staticmethod
def get_opcode_executor_stack():
# NOTE: only for debug.
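The new `add_alias` helper records an identity call in the SIR so that `dst` is bound to the same value as `src`. A hedged toy model of what gets recorded (the real `call_API` signature is not shown in this diff):

```python
recorded = []

def call_API(fn, inputs, outputs, stack_frames):
    # Stand-in for sir_ctx.call_API: append one statement to the program.
    recorded.append((fn.__name__, inputs, outputs))

def add_alias(src, dst):
    alias_fn = lambda x: x
    alias_fn.__name__ = "__sir_alias__"
    call_API(alias_fn, ([src], {}), dst, [])

add_alias("var_0", "var_1")
print(recorded)  # [('__sir_alias__', (['var_0'], {}), 'var_1')]
```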
@@ -224,6 +224,11 @@ def inner(self: OpcodeExecutorBase, instr: Instruction):
)(res)

assert isinstance(res, (ConstantVariable, SymbolicVariable))
# NOTE(SigureMo): Force to a constant to trigger the fallback to a static
# dim, aligning with the old behavior. A follow-up PR will support guarding
# the value with a constraint.
if isinstance(res, SymbolicVariable):
res = res.to_constant()
is_jump = res.get_py_value()
assert isinstance(is_jump, bool)
if is_jump:
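To make the fallback concrete: when a jump condition is a `SymbolicVariable`, it is first materialized into a `ConstantVariable`, so the dimension behind it stays static for this trace. A minimal sketch with stand-in classes (the real variables live in the SOT executor):

```python
class ConstantVariable:
    def __init__(self, value):
        self.value = value

    def get_py_value(self):
        return self.value

class SymbolicVariable:
    def __init__(self, concrete_value):
        self.concrete_value = concrete_value

    def to_constant(self):
        # Materialize the traced value; the dim falls back to static for now.
        return ConstantVariable(self.concrete_value)

res = SymbolicVariable(concrete_value=True)
if isinstance(res, SymbolicVariable):
    res = res.to_constant()
is_jump = res.get_py_value()
assert isinstance(is_jump, bool)
```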
5 changes: 3 additions & 2 deletions python/paddle/jit/sot/opcode_translator/executor/tracker.py
@@ -28,6 +28,7 @@
if TYPE_CHECKING:
from collections.abc import Sequence

from ...utils.magic_methods import BinaryOp, UnaryOp
from .pycode_generator import PyCodeGen
from .variables import VariableBase

@@ -131,9 +132,9 @@ class SymbolicOperationTracker(Tracker):
inputs (list[VariableBase]): The input variables associated with the generated variables.
"""

def __init__(self, inputs: Sequence[VariableBase], method_name: str):
def __init__(self, inputs: Sequence[VariableBase], op: UnaryOp | BinaryOp):
super().__init__(inputs)
self.method_name = method_name
self.op = op

def gen_instructions(self, codegen: PyCodeGen):
raise InnerError("SymbolicOperationTracker has no instructions")
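With the tracker storing the operator callable rather than a method-name string, replaying a symbolic operation becomes a direct call on the recorded inputs. A small illustrative sketch:

```python
import operator

class SymbolicOperationTracker:
    def __init__(self, inputs, op):
        self.inputs = inputs
        self.op = op  # e.g. operator.add instead of the string "__add__"

tracker = SymbolicOperationTracker([2, 3], operator.add)
print(tracker.op(*tracker.inputs))  # 5: replay the recorded operation
```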
106 changes: 53 additions & 53 deletions python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py
@@ -24,6 +24,12 @@

import paddle

from ...symbolic_shape import (
SYMBOLIC_BINARY_OPS,
SYMBOLIC_UNARY_OPS,
symbolic_not,
symbolic_to_bool,
)
from ...utils import (
BreakGraphError,
BuiltinFunctionBreak,
@@ -39,7 +45,6 @@
UNARY_OPS,
magic_method_builtin_dispatch,
)
from ...utils.paddle_api_config import get_tensor_methods
from .dispatch_functions import (
create_raise_break_graph_handler,
generator_send,
@@ -306,14 +311,6 @@ def dispatch_dict_fromkeys(
lambda var: var.keys(),
)

Dispatcher.register(
operator.not_,
("VariableBase",),
lambda x: ConstantVariable(
not x.get_py_value(allow_tensor=False), x.graph, DummyTracker([x])
),
)

Dispatcher.register(
dict.values,
("DictVariable",),
@@ -1047,7 +1044,7 @@ def is_not_func(var: VariableBase, other: VariableBase):
binary_fn,
),
)
# Tensor and Symbolic
# Tensor
fallback_tensor_unary_method = {
int,
bool,
@@ -1087,16 +1084,6 @@ def is_not_func(var: VariableBase, other: VariableBase):
magic_method.name,
),
)
Dispatcher.register(
unary_fn,
("SymbolicVariable",),
partial(
lambda magic_name, var: var.graph.call_symbolic_method(
magic_name, var
),
magic_method.name,
),
)
for binary_fn in BINARY_OPS:
for magic_method in magic_method_builtin_dispatch(binary_fn):
# skip all inplace magic method name, we will dispatch it to non-inplace
@@ -1152,40 +1139,43 @@ def tensor_mod_dispatcher(
),
)

for binary_fn in BINARY_OPS:
for magic_method in magic_method_builtin_dispatch(binary_fn):
if magic_method.name not in get_tensor_methods():
continue
# skip all inplace magic method name, we will dispatch it to non-inplace
# magic methods
if magic_method.is_inplace:
continue
# Symbolic
for unary_fn in SYMBOLIC_UNARY_OPS:
Dispatcher.register(
unary_fn,
("SymbolicVariable",),
partial(
lambda fn, var: var.graph.call_symbolic_api(fn, var),
unary_fn,
),
)
for binary_fn in SYMBOLIC_BINARY_OPS:
Dispatcher.register(
binary_fn,
("SymbolicVariable", "SymbolicVariable | ConstantVariable"),
partial(
lambda fn, var, other: var.graph.call_symbolic_api(fn, var, other),
binary_fn,
),
)
Dispatcher.register(
binary_fn,
("ConstantVariable", "SymbolicVariable"),
partial(
lambda fn, var, other: var.graph.call_symbolic_api(fn, var, other),
binary_fn,
),
)

if not magic_method.is_reverse:
Dispatcher.register(
binary_fn,
(
"SymbolicVariable",
"ConstantVariable | SymbolicVariable",
),
partial(
lambda magic_name, var, other: var.graph.call_symbolic_method(
magic_name, var, other
),
magic_method.name,
),
)
else:
Dispatcher.register(
binary_fn,
("ConstantVariable", "SymbolicVariable"),
partial(
lambda reverse_magic_name, var, other: var.graph.call_symbolic_method(
reverse_magic_name, other, var
),
magic_method.name,
),
)

@Dispatcher.register_decorator(bool)
def dispatch_symbolic_bool(var: SymbolicVariable):
return BuiltinVariable(symbolic_to_bool, var.graph, DanglingTracker())(var)


@Dispatcher.register_decorator(operator.not_)
def dispatch_symbolic_not(var: SymbolicVariable):
return BuiltinVariable(symbolic_not, var.graph, DanglingTracker())(var)


# Register dispatch for DataVariable: directly call and return a wrapped variable.
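One detail worth calling out in the registration loops above: the handlers bind the loop variable through `functools.partial` rather than closing over it directly, because a bare closure captures `unary_fn`/`binary_fn` late and every handler would end up calling the last op of the loop. A self-contained demonstration of the pitfall and the fix:

```python
import operator
from functools import partial

ops = [operator.add, operator.mul]

late_bound = [lambda a, b: fn(a, b) for fn in ops]                  # buggy
early_bound = [partial(lambda f, a, b: f(a, b), fn) for fn in ops]  # correct

print(late_bound[0](2, 3))   # 6 -- both closures see operator.mul
print(early_bound[0](2, 3))  # 5 -- operator.add was bound at creation time
```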
@@ -1600,3 +1590,13 @@ def dispatch_all(var: ContainerVariable | IterVariable):
("PlaceVariable",),
lambda var: var.get_device_type(),
)

# Not for all variables.
# TODO(SigureMo): Optimize this dispatch
Dispatcher.register(
operator.not_,
("VariableBase",),
lambda x: ConstantVariable(
not x.get_py_value(allow_tensor=False), x.graph, DummyTracker([x])
),
)
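Finally, on moving the generic `operator.not_` registration for `VariableBase` to the end of the file: the likely intent (an assumption here, based on the TODO above) is that the catch-all must not shadow the new `SymbolicVariable`-specific `operator.not_` handler. A toy first-match dispatcher makes the ordering concern concrete:

```python
registry = []

def register(typ, handler):
    registry.append((typ, handler))

def dispatch(value):
    # First registered match wins, so catch-alls must come last.
    for typ, handler in registry:
        if isinstance(value, typ):
            return handler(value)
    raise TypeError(f"no handler for {type(value).__name__}")

class Symbolic(int):
    pass

register(Symbolic, lambda v: "symbolic not_")  # specific handler first
register(object, lambda v: "generic not_")     # catch-all registered last

print(dispatch(Symbolic(1)))  # symbolic not_
print(dispatch(1))            # generic not_
```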