Skip to content
Prev Previous commit
Next Next commit
Add RuleOptionsType
  • Loading branch information
chrisjsewell committed May 31, 2023
commit 755b0a08c366dff2fd0f34cf79c0687ff404f8fc
48 changes: 37 additions & 11 deletions markdown_it/ruler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Ruler

from collections.abc import Callable, Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypedDict

from markdown_it._compat import DATACLASS_KWARGS

Expand Down Expand Up @@ -51,7 +51,11 @@ def src(self, value: str) -> None:
# arguments may or may not exist, based on the rule's type (block,
# core, inline). Return type is either `None` or `bool` based on the
# rule's type.
RuleFunc = Callable
RuleFunc = Callable # type: ignore


class RuleOptionsType(TypedDict, total=False):
alt: list[str]


@dataclass(**DATACLASS_KWARGS)
Expand Down Expand Up @@ -97,7 +101,9 @@ def __compile__(self) -> None:
continue
self.__cache__[chain].append(rule.fn)

def at(self, ruleName: str, fn: RuleFunc, options=None):
def at(
self, ruleName: str, fn: RuleFunc, options: RuleOptionsType | None = None
) -> None:
"""Replace rule by name with new function & options.

:param ruleName: rule name to replace.
Expand All @@ -113,7 +119,13 @@ def at(self, ruleName: str, fn: RuleFunc, options=None):
self.__rules__[index].alt = options.get("alt", [])
self.__cache__ = None

def before(self, beforeName: str, ruleName: str, fn: RuleFunc, options=None):
def before(
self,
beforeName: str,
ruleName: str,
fn: RuleFunc,
options: RuleOptionsType | None = None,
) -> None:
"""Add new rule to chain before one with given name.

:param beforeName: new rule will be added before this one.
Expand All @@ -129,7 +141,13 @@ def before(self, beforeName: str, ruleName: str, fn: RuleFunc, options=None):
self.__rules__.insert(index, Rule(ruleName, True, fn, options.get("alt", [])))
self.__cache__ = None

def after(self, afterName: str, ruleName: str, fn: RuleFunc, options=None):
def after(
self,
afterName: str,
ruleName: str,
fn: RuleFunc,
options: RuleOptionsType | None = None,
) -> None:
"""Add new rule to chain after one with given name.

:param afterName: new rule will be added after this one.
Expand All @@ -147,7 +165,9 @@ def after(self, afterName: str, ruleName: str, fn: RuleFunc, options=None):
)
self.__cache__ = None

def push(self, ruleName: str, fn: RuleFunc, options=None):
def push(
self, ruleName: str, fn: RuleFunc, options: RuleOptionsType | None = None
) -> None:
"""Push new rule to the end of chain.

:param ruleName: new rule will be added to the end of chain.
Expand All @@ -158,7 +178,9 @@ def push(self, ruleName: str, fn: RuleFunc, options=None):
self.__rules__.append(Rule(ruleName, True, fn, (options or {}).get("alt", [])))
self.__cache__ = None

def enable(self, names: str | Iterable[str], ignoreInvalid: bool = False):
def enable(
self, names: str | Iterable[str], ignoreInvalid: bool = False
) -> list[str]:
"""Enable rules with given names.

:param names: name or list of rule names to enable.
Expand All @@ -168,7 +190,7 @@ def enable(self, names: str | Iterable[str], ignoreInvalid: bool = False):
"""
if isinstance(names, str):
names = [names]
result = []
result: list[str] = []
for name in names:
idx = self.__find__(name)
if (idx < 0) and ignoreInvalid:
Expand All @@ -180,7 +202,9 @@ def enable(self, names: str | Iterable[str], ignoreInvalid: bool = False):
self.__cache__ = None
return result

def enableOnly(self, names: str | Iterable[str], ignoreInvalid: bool = False):
def enableOnly(
self, names: str | Iterable[str], ignoreInvalid: bool = False
) -> list[str]:
"""Enable rules with given names, and disable everything else.

:param names: name or list of rule names to enable.
Expand All @@ -192,9 +216,11 @@ def enableOnly(self, names: str | Iterable[str], ignoreInvalid: bool = False):
names = [names]
for rule in self.__rules__:
rule.enabled = False
self.enable(names, ignoreInvalid)
return self.enable(names, ignoreInvalid)

def disable(self, names: str | Iterable[str], ignoreInvalid: bool = False):
def disable(
self, names: str | Iterable[str], ignoreInvalid: bool = False
) -> list[str]:
"""Disable rules with given names.

:param names: name or list of rule names to enable.
Expand Down