33# Copyright (c) https://github.com/PyCQA/pylint/blob/main/CONTRIBUTORS.txt
44
55"""Diagram objects."""
6+
7+ from __future__ import annotations
8+
9+ from collections .abc import Iterable
10+ from typing import Any
11+
612import astroid
713from astroid import nodes
814
1319class Figure :
1420 """Base class for counter handling."""
1521
22+ def __init__ (self ) -> None :
23+ self .fig_id : str = ""
24+
1625
1726class Relationship (Figure ):
1827 """A relationship from an object in the diagram to another."""
1928
20- def __init__ (self , from_object , to_object , relation_type , name = None ):
29+ def __init__ (
30+ self ,
31+ from_object : DiagramEntity ,
32+ to_object : DiagramEntity ,
33+ relation_type : str ,
34+ name : str | None = None ,
35+ ):
2136 super ().__init__ ()
2237 self .from_object = from_object
2338 self .to_object = to_object
@@ -28,59 +43,76 @@ def __init__(self, from_object, to_object, relation_type, name=None):
2843class DiagramEntity (Figure ):
2944 """A diagram object, i.e. a label associated to an astroid node."""
3045
31- def __init__ (self , title = "No name" , node = None ):
46+ default_shape = ""
47+
48+ def __init__ (
49+ self , title : str = "No name" , node : nodes .NodeNG | None = None
50+ ) -> None :
3251 super ().__init__ ()
3352 self .title = title
34- self .node = node
53+ self .node : nodes .NodeNG = node if node else nodes .NodeNG ()
54+ self .shape = self .default_shape
3555
3656
3757class PackageEntity (DiagramEntity ):
3858 """A diagram object representing a package."""
3959
60+ default_shape = "package"
61+
4062
4163class ClassEntity (DiagramEntity ):
4264 """A diagram object representing a class."""
4365
44- def __init__ (self , title , node ):
66+ default_shape = "class"
67+
68+ def __init__ (self , title : str , node : nodes .ClassDef ) -> None :
4569 super ().__init__ (title = title , node = node )
46- self .attrs = None
47- self .methods = None
70+ self .attrs : list [ str ] = []
71+ self .methods : list [ nodes . FunctionDef ] = []
4872
4973
5074class ClassDiagram (Figure , FilterMixIn ):
5175 """Main class diagram handling."""
5276
5377 TYPE = "class"
5478
55- def __init__ (self , title , mode ) :
79+ def __init__ (self , title : str , mode : str ) -> None :
5680 FilterMixIn .__init__ (self , mode )
5781 Figure .__init__ (self )
5882 self .title = title
59- self . objects = []
60- self .relationships = {}
61- self ._nodes = {}
62- self .depends = []
83+ # TODO: Specify 'Any' after refactor of `DiagramEntity`
84+ self .objects : list [ Any ] = []
85+ self .relationships : dict [ str , list [ Relationship ]] = {}
86+ self ._nodes : dict [ nodes . NodeNG , DiagramEntity ] = {}
6387
64- def get_relationships (self , role ) :
88+ def get_relationships (self , role : str ) -> Iterable [ Relationship ] :
6589 # sorted to get predictable (hence testable) results
6690 return sorted (
6791 self .relationships .get (role , ()),
6892 key = lambda x : (x .from_object .fig_id , x .to_object .fig_id ),
6993 )
7094
71- def add_relationship (self , from_object , to_object , relation_type , name = None ):
95+ def add_relationship (
96+ self ,
97+ from_object : DiagramEntity ,
98+ to_object : DiagramEntity ,
99+ relation_type : str ,
100+ name : str | None = None ,
101+ ) -> None :
72102 """Create a relationship."""
73103 rel = Relationship (from_object , to_object , relation_type , name )
74104 self .relationships .setdefault (relation_type , []).append (rel )
75105
76- def get_relationship (self , from_object , relation_type ):
106+ def get_relationship (
107+ self , from_object : DiagramEntity , relation_type : str
108+ ) -> Relationship :
77109 """Return a relationship or None."""
78110 for rel in self .relationships .get (relation_type , ()):
79111 if rel .from_object is from_object :
80112 return rel
81113 raise KeyError (relation_type )
82114
83- def get_attrs (self , node ) :
115+ def get_attrs (self , node : nodes . ClassDef ) -> list [ str ] :
84116 """Return visible attributes, possibly with class name."""
85117 attrs = []
86118 properties = [
@@ -101,7 +133,7 @@ def get_attrs(self, node):
101133 attrs .append (node_name )
102134 return sorted (attrs )
103135
104- def get_methods (self , node ) :
136+ def get_methods (self , node : nodes . ClassDef ) -> list [ nodes . FunctionDef ] :
105137 """Return visible methods."""
106138 methods = [
107139 m
@@ -113,14 +145,14 @@ def get_methods(self, node):
113145 ]
114146 return sorted (methods , key = lambda n : n .name )
115147
116- def add_object (self , title , node ) :
148+ def add_object (self , title : str , node : nodes . ClassDef ) -> None :
117149 """Create a diagram object."""
118150 assert node not in self ._nodes
119- ent = DiagramEntity (title , node )
151+ ent = ClassEntity (title , node )
120152 self ._nodes [node ] = ent
121153 self .objects .append (ent )
122154
123- def class_names (self , nodes_lst ) :
155+ def class_names (self , nodes_lst : Iterable [ nodes . NodeNG ]) -> list [ str ] :
124156 """Return class names if needed in diagram."""
125157 names = []
126158 for node in nodes_lst :
@@ -136,30 +168,26 @@ def class_names(self, nodes_lst):
136168 names .append (node_name )
137169 return names
138170
139- def nodes (self ):
140- """Return the list of underlying nodes."""
141- return self ._nodes .keys ()
142-
143- def has_node (self , node ):
171+ def has_node (self , node : nodes .NodeNG ) -> bool :
144172 """Return true if the given node is included in the diagram."""
145173 return node in self ._nodes
146174
147- def object_from_node (self , node ) :
175+ def object_from_node (self , node : nodes . NodeNG ) -> DiagramEntity :
148176 """Return the diagram object mapped to node."""
149177 return self ._nodes [node ]
150178
151- def classes (self ):
179+ def classes (self ) -> list [ ClassEntity ] :
152180 """Return all class nodes in the diagram."""
153- return [o for o in self .objects if isinstance (o . node , nodes . ClassDef )]
181+ return [o for o in self .objects if isinstance (o , ClassEntity )]
154182
155- def classe (self , name ) :
183+ def classe (self , name : str ) -> ClassEntity :
156184 """Return a class by its name, raise KeyError if not found."""
157185 for klass in self .classes ():
158186 if klass .node .name == name :
159187 return klass
160188 raise KeyError (name )
161189
162- def extract_relationships (self ):
190+ def extract_relationships (self ) -> None :
163191 """Extract relationships between nodes in the diagram."""
164192 for obj in self .classes ():
165193 node = obj .node
@@ -205,18 +233,25 @@ class PackageDiagram(ClassDiagram):
205233
206234 TYPE = "package"
207235
208- def modules (self ):
236+ def modules (self ) -> list [ PackageEntity ] :
209237 """Return all module nodes in the diagram."""
210- return [o for o in self .objects if isinstance (o . node , nodes . Module )]
238+ return [o for o in self .objects if isinstance (o , PackageEntity )]
211239
212- def module (self , name ) :
240+ def module (self , name : str ) -> PackageEntity :
213241 """Return a module by its name, raise KeyError if not found."""
214242 for mod in self .modules ():
215243 if mod .node .name == name :
216244 return mod
217245 raise KeyError (name )
218246
219- def get_module (self , name , node ):
247+ def add_object (self , title : str , node : nodes .Module ) -> None :
248+ """Create a diagram object."""
249+ assert node not in self ._nodes
250+ ent = PackageEntity (title , node )
251+ self ._nodes [node ] = ent
252+ self .objects .append (ent )
253+
254+ def get_module (self , name : str , node : nodes .Module ) -> PackageEntity :
220255 """Return a module by its name, looking also for relative imports;
221256 raise KeyError if not found.
222257 """
@@ -232,29 +267,29 @@ def get_module(self, name, node):
232267 return mod
233268 raise KeyError (name )
234269
235- def add_from_depend (self , node , from_module ) :
270+ def add_from_depend (self , node : nodes . ImportFrom , from_module : str ) -> None :
236271 """Add dependencies created by from-imports."""
237272 mod_name = node .root ().name
238273 obj = self .module (mod_name )
239274 if from_module not in obj .node .depends :
240275 obj .node .depends .append (from_module )
241276
242- def extract_relationships (self ):
277+ def extract_relationships (self ) -> None :
243278 """Extract relationships between nodes in the diagram."""
244279 super ().extract_relationships ()
245- for obj in self .classes ():
280+ for class_obj in self .classes ():
246281 # ownership
247282 try :
248- mod = self .object_from_node (obj .node .root ())
249- self .add_relationship (obj , mod , "ownership" )
283+ mod = self .object_from_node (class_obj .node .root ())
284+ self .add_relationship (class_obj , mod , "ownership" )
250285 except KeyError :
251286 continue
252- for obj in self .modules ():
253- obj .shape = "package"
287+ for package_obj in self .modules ():
288+ package_obj .shape = "package"
254289 # dependencies
255- for dep_name in obj .node .depends :
290+ for dep_name in package_obj .node .depends :
256291 try :
257- dep = self .get_module (dep_name , obj .node )
292+ dep = self .get_module (dep_name , package_obj .node )
258293 except KeyError :
259294 continue
260- self .add_relationship (obj , dep , "depends" )
295+ self .add_relationship (package_obj , dep , "depends" )
0 commit comments