Skip to content

Commit 949dc05

Browse files
justinchubypytorchmergebot
authored andcommitted
Add methods to torch._C pyi (pytorch#78757)
Add python bound methods in `Block`, `Node`, `Graph` and `Value` to `torch._C.__init__.pyi` to enable proper type hints in editors. Pull Request resolved: pytorch#78757 Approved by: https://github.com/eellison, https://github.com/BowenBao
1 parent e675dba commit 949dc05

File tree

1 file changed

+46
-6
lines changed

1 file changed

+46
-6
lines changed

torch/_C/__init__.pyi.in

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -457,40 +457,80 @@ class Value:
457457
def type(self)-> JitType: ...
458458
def setType(self, t: JitType) -> Value: ...
459459
def debugName(self) -> str: ...
460+
def requires_grad(self) -> _bool: ...
460461
...
461462

462463
# Defined in torch/csrc/jit/ir/ir.h
463464
class Block:
464465
def inputs(self) -> List[Value]: ...
465466
def outputs(self) -> List[Value]: ...
467+
def nodes(self) -> Iterator[Node]: ...
468+
def paramNode(self) -> Node: ...
469+
def returnNode(self) -> Node: ...
470+
def owningNode(self) -> Node: ...
471+
def registerOutput(self, n: Value) -> _int: ...
466472
...
467473

468474
# Defined in torch/csrc/jit/ir/ir.h
469475
class Node:
476+
def __getitem__(self, key: str) -> Any: ...
470477
def schema(self) -> str: ...
478+
def input(self) -> Value: ...
479+
def inputs(self) -> List[Value]: ...
471480
def output(self) -> Value: ...
472481
def outputs(self) -> List[Value]: ...
473482
def outputsSize(self) -> _int: ...
474483
def blocks(self) -> List[Block]: ...
484+
def addBlock(self) -> Block: ...
475485
def mustBeNone(self) -> _bool: ...
476-
def kindOf(self, str) -> str: ...
477-
def __getitem__(self, key: str) -> Any: ...
478-
def namedInput(self, str) -> Value: ...
479-
def sourceRange(self) -> SourceRange: ...
480486
def kind(self) -> str: ...
487+
def kindOf(self, name: str) -> str: ...
488+
def addInput(self, name: str) -> Value: ...
489+
def replaceInput(self, i: _int, newValue: Value) -> Value: ...
490+
def replaceInputWith(self, from_: Value, to: Value) -> None: ...
491+
def replaceAllUsesWith(self, n: Node) -> None: ...
492+
def insertBefore(self, n: Node) -> Node: ...
493+
def insertAfter(self, n: Node) -> Node: ...
494+
def isBefore(self, n: Node) -> _bool: ...
495+
def isAfter(self, n: Node) -> _bool: ...
496+
def moveBefore(self, n: Node) -> None: ...
497+
def moveAfter(self, n: Node) -> None: ...
498+
def removeInput(self, i: _int) -> None: ...
499+
def removeAllInputs(self, i: _int) -> None: ...
500+
def hasUses(self) -> _bool: ...
501+
def eraseOutput(self, i: _int) -> None: ...
502+
def addOutput(self) -> Value: ...
503+
def scopeName(self) -> str: ...
504+
def isNondeterministic(self) -> _bool: ...
505+
def copyAttributes(self, rhs: Node) -> Node: ...
506+
def hasAttributes(self, name: str) -> _bool: ...
507+
def namedInput(self, name: str) -> Value: ...
508+
def sourceRange(self) -> SourceRange: ...
481509
...
482510

483511
# Defined in torch/torch/csrc/jit/ir/ir.h
484512
class Graph:
513+
def inputs(self) -> List[Value]: ...
514+
def outputs(self) -> List[Value]: ...
515+
def nodes(self) -> Iterator[Node]: ...
516+
def param_node(self) -> Node: ...
517+
def return_node(self) -> Node: ...
518+
def addInput(self, name: str) -> Value: ...
485519
def eraseInput(self, i: _int) -> None: ...
520+
def registerOutput(self, n: Value) -> _int: ...
521+
def eraseOutput(self, i: _int) -> None: ...
522+
def create(self, name: str, args, num_outputs: _int) -> Node: ...
523+
def appendNode(self, n: Node) -> Node: ...
524+
def prependNode(self, n: Node) -> Node: ...
525+
def insertNode(self, n: Node) -> Node: ...
526+
def block(self) -> Block: ...
527+
def lint(self) -> None: ...
486528
def alias_db(self) -> AliasDb: ...
487-
def inputs(self) -> List[Value]: ...
488529
def setInsertPoint(self, n: Union[Block, Node]) -> None: ...
489530
def insert_point_guard(self, n: Union[Block, Node]) -> _InsertPoint: ...
490531
def insertPoint(self) -> Node: ...
491532
def insertGraph(self, callee: Graph, inputs: List[Value]) -> List[Value]: ...
492533
def makeMultiOutputIntoTuple(self) -> None: ...
493-
def nodes(self) -> Iterator: ...
494534
...
495535

496536

0 commit comments

Comments
 (0)