@@ -1079,23 +1079,38 @@ PyLocation &DefaultingPyLocation::resolve() {
10791079PyModule::PyModule (PyMlirContextRef contextRef, MlirModule module )
10801080 : BaseContextObject(std::move(contextRef)), module (module ) {}
10811081
1082- PyModule::~PyModule () { mlirModuleDestroy (module ); }
1082+ PyModule::~PyModule () {
1083+ nb::gil_scoped_acquire acquire;
1084+ auto &liveModules = getContext ()->liveModules ;
1085+ assert (liveModules.count (module .ptr ) == 1 &&
1086+ " destroying module not in live map" );
1087+ liveModules.erase (module .ptr );
1088+ mlirModuleDestroy (module );
1089+ }
10831090
10841091PyModuleRef PyModule::forModule (MlirModule module ) {
10851092 MlirContext context = mlirModuleGetContext (module );
10861093 PyMlirContextRef contextRef = PyMlirContext::forContext (context);
10871094
1088- // Create.
1089- PyModule *unownedModule = new PyModule (std::move (contextRef), module );
1090- // Note that the default return value policy on cast is `automatic_reference`,
1091- // which means "does not take ownership, does not call delete/dtor".
1092- // We use `take_ownership`, which means "Python will call the C++ destructor
1093- // and delete operator when the Python wrapper is garbage collected", because
1094- // MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse
1095- // etc).
1096- nb::object pyRef = nb::cast (unownedModule, nb::rv_policy::take_ownership);
1097- unownedModule->handle = pyRef;
1098- return PyModuleRef (unownedModule, std::move (pyRef));
1095+ nb::gil_scoped_acquire acquire;
1096+ auto &liveModules = contextRef->liveModules ;
1097+ auto it = liveModules.find (module .ptr );
1098+ if (it == liveModules.end ()) {
1099+ // Create.
1100+ PyModule *unownedModule = new PyModule (std::move (contextRef), module );
1101+ // Note that the default return value policy on cast is automatic_reference,
1102+ // which does not take ownership (delete will not be called).
1103+ // Just be explicit.
1104+ nb::object pyRef = nb::cast (unownedModule, nb::rv_policy::take_ownership);
1105+ unownedModule->handle = pyRef;
1106+ liveModules[module .ptr ] =
1107+ std::make_pair (unownedModule->handle , unownedModule);
1108+ return PyModuleRef (unownedModule, std::move (pyRef));
1109+ }
1110+ // Use existing.
1111+ PyModule *existing = it->second .second ;
1112+ nb::object pyRef = nb::borrow<nb::object>(it->second .first );
1113+ return PyModuleRef (existing, std::move (pyRef));
10991114}
11001115
11011116nb::object PyModule::createFromCapsule (nb::object capsule) {
@@ -2084,6 +2099,8 @@ PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
20842099 return PyInsertionPoint{block, std::move (nextOpRef)};
20852100}
20862101
2102+ size_t PyMlirContext::getLiveModuleCount () { return liveModules.size (); }
2103+
20872104nb::object PyInsertionPoint::contextEnter (nb::object insertPoint) {
20882105 return PyThreadContextEntry::pushInsertionPoint (insertPoint);
20892106}
@@ -2923,6 +2940,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
29232940 PyMlirContextRef ref = PyMlirContext::forContext (self.get ());
29242941 return ref.releaseObject ();
29252942 })
2943+ .def (" _get_live_module_count" , &PyMlirContext::getLiveModuleCount)
29262944 .def_prop_ro (MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
29272945 .def (MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
29282946 .def (" __enter__" , &PyMlirContext::contextEnter)
0 commit comments