Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
139 commits
Select commit Hold shift + click to select a range
eee1fc4
initial files for louvain
icfaust May 21, 2024
51c35a9
start switch to louvain
icfaust May 21, 2024
623b265
update
icfaust May 22, 2024
e2a05df
Merge branch 'intel:main' into dev/louvain
icfaust Jun 3, 2024
ff51119
very much incomplete
icfaust Jun 3, 2024
3d09c58
fix vertex_partitioning dispatch
icfaust Jun 5, 2024
4fd8a07
compute -> vertex_paritioning
icfaust Jun 5, 2024
5bdbdf0
fix in louvain.cpp
icfaust Jun 5, 2024
ad4f31b
Merge branch 'intel:main' into dev/louvain
icfaust Jun 5, 2024
6fe7853
interim changes
icfaust Jun 5, 2024
68d91d4
finally compiles, error in oneDAL
icfaust Jun 7, 2024
eeab735
remove vestigial code
icfaust Jun 7, 2024
b918399
add warning
icfaust Jun 7, 2024
0a40c09
fix python code
icfaust Jun 7, 2024
ebbb29e
change copyright
icfaust Jun 7, 2024
c26c860
Merge branch 'intel:main' into dev/louvain
icfaust Jun 11, 2024
09a043d
update
icfaust Jun 11, 2024
66ed008
update copyright
icfaust Jun 11, 2024
221d725
switch name
icfaust Jun 11, 2024
3ef91a6
halfway there
icfaust Jun 11, 2024
6c6dbda
add graph allocator
icfaust Jun 12, 2024
c6c8734
interim fixes
icfaust Jun 12, 2024
97dba43
able to compile
icfaust Jun 12, 2024
ac69c29
attempt to fix undefined symbol
icfaust Jun 12, 2024
eae1694
one fixed for missing symbols
icfaust Jun 19, 2024
a206367
Merge branch 'intel:main' into dev/louvain
icfaust Jul 15, 2024
91f18e3
now compiles and loads in python
icfaust Jul 16, 2024
a3eaad4
change language
icfaust Jul 16, 2024
98a049c
formatting
icfaust Jul 16, 2024
8a76b01
improvements in python-side
icfaust Jul 16, 2024
3c34a2d
move comment
icfaust Jul 16, 2024
8a7b356
forgotten to_table
icfaust Jul 16, 2024
d77f289
change default
icfaust Jul 16, 2024
0e7a6a3
add informative message to assert
icfaust Jul 16, 2024
0af5305
renaming
icfaust Jul 16, 2024
4bc7370
isort and formatting
icfaust Jul 16, 2024
d948cc8
removing unnecessary code
icfaust Jul 16, 2024
939ad77
simplify for testing
icfaust Jul 16, 2024
7398b2e
add assert
icfaust Jul 16, 2024
a6f1f90
fix module call
icfaust Jul 16, 2024
8b93879
on->for
icfaust Jul 16, 2024
396fd45
add __init__
icfaust Jul 16, 2024
cb92739
fixes
icfaust Jul 16, 2024
6a69e68
need to work on the graph verification
icfaust Jul 16, 2024
1acd939
updates
icfaust Jul 17, 2024
fc29329
Merge branch 'intel:main' into dev/louvain
icfaust Jul 25, 2024
1c79029
interim with attempts at proper graphs
icfaust Jul 25, 2024
89cb147
first light
icfaust Jul 25, 2024
1d52657
remove prints
icfaust Jul 25, 2024
9698364
add tests
icfaust Jul 26, 2024
ec4e301
odds and ends
icfaust Jul 26, 2024
4b33064
formatting
icfaust Jul 26, 2024
5da59b0
intermediate save
icfaust Jul 26, 2024
3bb353c
add _louvain
icfaust Jul 26, 2024
bdfff42
nearing completion for testing
icfaust Jul 26, 2024
6ae03e0
properly share values between objects
icfaust Jul 26, 2024
6d23c8e
forgotten change
icfaust Jul 26, 2024
7bf6dbf
reformatting
icfaust Jul 26, 2024
3145abf
changes needed for clarity
icfaust Jul 26, 2024
d126eea
Merge branch 'intel:main' into dev/louvain
icfaust Jul 26, 2024
659c836
switch to nearest_neighbors
icfaust Jul 26, 2024
23234a9
formatting
icfaust Jul 26, 2024
e839004
address missing sklearn_check_version
icfaust Jul 26, 2024
884a0bd
forgotten check_array
icfaust Jul 26, 2024
68de824
remove patching
icfaust Jul 26, 2024
c2335f1
Update _louvain.py
icfaust Jul 26, 2024
26a3a4f
Update _louvain.py
icfaust Jul 26, 2024
db9159d
formatting
icfaust Jul 26, 2024
bfb5a34
Update test_louvain.py
icfaust Jul 26, 2024
4f2e224
Update test_louvain.py
icfaust Jul 26, 2024
a1ddbd0
Update _louvain.py
icfaust Jul 26, 2024
3beb349
Update _louvain.py
icfaust Jul 26, 2024
1fe7809
Update _louvain.py
icfaust Jul 26, 2024
303482d
Update dispatcher.py
icfaust Jul 26, 2024
45c8957
Update dispatcher.py
icfaust Jul 26, 2024
223e2fd
Update louvain.py
icfaust Jul 27, 2024
108824a
Update test_louvain.py
icfaust Jul 27, 2024
ab41cc6
Update test_louvain.py
icfaust Jul 27, 2024
3d6b1ca
Update test_louvain.py
icfaust Jul 27, 2024
2590c76
Update test_louvain.py
icfaust Jul 27, 2024
ebba246
X -> self.affinity_matrix_
icfaust Jul 29, 2024
9e8e83a
formatting
icfaust Jul 29, 2024
4ce716f
improve documentation
icfaust Jul 29, 2024
1b48fbb
re-enable decrefs for memory leaks
icfaust Jul 29, 2024
5a66ae5
won't compile otherwise
icfaust Jul 29, 2024
f1350d1
disable y support temporarily
icfaust Jul 29, 2024
5429f00
formatting
icfaust Jul 29, 2024
4f11ba0
decref segfaults, need to look deeper
icfaust Jul 29, 2024
675f71a
updates for common tests
icfaust Jul 29, 2024
e975658
Update _louvain.py
icfaust Jul 29, 2024
fa8a008
attempt to remove memory leak bluntly
icfaust Jul 30, 2024
4b18bbc
roughly check to_graph
icfaust Jul 30, 2024
ce3f954
formatting
icfaust Jul 30, 2024
d3781f2
Update data_conversion.cpp
icfaust Jul 30, 2024
34b8ad4
Update louvain.py
icfaust Jul 31, 2024
c3a5a8e
Update louvain.py
icfaust Jul 31, 2024
9edbef2
move to full copy
icfaust Jul 31, 2024
e1ebcb1
decref
icfaust Jul 31, 2024
585df0d
re-enable for last testing before preview
icfaust Jul 31, 2024
5eb396a
revert
icfaust Jul 31, 2024
6d851d3
Merge branch 'intel:main' into dev/louvain
icfaust Aug 1, 2024
3d390eb
fixes
icfaust Aug 1, 2024
f7d87b4
cols -> col
icfaust Aug 1, 2024
ddc2064
revert to original
icfaust Aug 1, 2024
e819091
move to preview
icfaust Aug 1, 2024
c7d6e04
readd missing file
icfaust Aug 1, 2024
28acde1
remove _louvain
icfaust Aug 1, 2024
d8178a0
fix isort
icfaust Aug 1, 2024
6acd6f7
add input name
icfaust Aug 1, 2024
b1c7a5e
deactivate check
icfaust Aug 1, 2024
f34969b
updates for preview
icfaust Aug 1, 2024
59e9cac
Add eps parameter for truncation
icfaust Aug 2, 2024
09deffb
Merge branch 'intel:main' into dev/louvain
icfaust Aug 2, 2024
273c916
Merge branch 'intel:main' into dev/louvain
icfaust Aug 9, 2024
f493d9e
add fixes for pairwise_kernels
icfaust Aug 9, 2024
6c40348
additional changes
icfaust Aug 9, 2024
3c66317
Merge branch 'intel:main' into dev/louvain
icfaust Sep 3, 2024
33f79e5
merge main
icfaust Oct 11, 2024
753738f
readd imports
icfaust Oct 11, 2024
5d19d2e
forgotten file
icfaust Oct 11, 2024
c3939ac
formatting
icfaust Oct 11, 2024
4c30495
Update dispatcher.py
icfaust Oct 11, 2024
d1be448
Update setup.py
icfaust Oct 11, 2024
4bcefd1
Merge branch 'intel:main' into dev/louvain
icfaust Oct 22, 2024
97d9515
Merge branch 'main' into dev/louvain
icfaust Nov 9, 2024
0944edc
Update _data_conversion.py
icfaust Nov 9, 2024
68d3d2e
Update _data_conversion.py
icfaust Nov 9, 2024
5f340b9
Merge branch 'main' into dev/louvain
icfaust Nov 15, 2024
f37413c
Merge branch 'intel:main' into dev/louvain
icfaust Nov 25, 2024
544b9d8
will probably not compile
icfaust Nov 27, 2024
bae7fb0
work on CI failures
icfaust Nov 27, 2024
558b6cf
doesn't need lambda, needs template function
icfaust Nov 27, 2024
fe31b9d
still not compiling
icfaust Nov 27, 2024
6fc7680
static_cast -> reinterpret_cast
icfaust Nov 27, 2024
430011d
fighting with the const
icfaust Nov 27, 2024
d276446
add prints for laziness
icfaust Nov 27, 2024
e63f6ec
Merge branch 'main' into dev/louvain
icfaust Feb 9, 2025
6d7420d
Update louvain.cpp
icfaust Feb 9, 2025
93932c5
Update louvain.cpp
icfaust Feb 9, 2025
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion onedal/cluster/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@

from .dbscan import DBSCAN
from .kmeans import KMeans, k_means
from .louvain import Louvain

__all__ = ["DBSCAN", "KMeans", "k_means"]
__all__ = ["DBSCAN", "KMeans", "k_means", "Louvain"]

if daal_check_version((2023, "P", 200)):
from .kmeans_init import KMeansInit, kmeans_plusplus
Expand Down
220 changes: 220 additions & 0 deletions onedal/cluster/louvain.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
/*******************************************************************************
* Copyright 2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <type_traits>
#include <iostream>
#include "oneapi/dal/algo/louvain.hpp"

#include "oneapi/dal/graph/undirected_adjacency_vector_graph.hpp"
#include "oneapi/dal/graph/common.hpp"
#include "oneapi/dal/detail/memory.hpp"

#include "onedal/common.hpp"
#include "onedal/version.hpp"
#include "onedal/datatypes/numpy/data_conversion.hpp"

namespace py = pybind11;

namespace oneapi::dal::python {

// Graph type used by every Louvain binding in this file: oneDAL's preview undirected
// adjacency-vector graph instantiated with std::int32_t and Float.
// NOTE(review): presumably these are the vertex-value and edge-value types respectively
// (edge weights are later supplied as Float*) - confirm against the oneDAL graph header.
template <typename Float>
using graph_t = dal::preview::undirected_adjacency_vector_graph<std::int32_t, Float>;

/// Validates a oneDAL table that is expected to carry one CSR component (values, column
/// indices or row offsets) and exposes its underlying storage.
///
/// @tparam Type   Element type the caller expects the table to hold.
/// @param input   Table to inspect; must be a homogen table with exactly one column.
/// @param ptr     [out] Pointer to the first element of the table's original data.
/// @param length  [out] Number of rows (i.e. elements) in the table.
/// @throws std::invalid_argument if the table is not homogen, its dtype is unknown or does
///         not match `Type`, or it has more than one column.
template <typename Type>
inline void _table_checks(const table& input, Type*& ptr, std::int64_t& length) {
    if (input.get_kind() != dal::homogen_table::kind()) {
        throw std::invalid_argument("Non-homogen table input.");
    }
    const auto& homogen_input = static_cast<const dal::homogen_table&>(input);

    // Verify matching datatype. The exceptions have to be *thrown*: merely constructing a
    // std::invalid_argument temporary is a no-op and would let a mismatching buffer be
    // reinterpreted as `Type` below.
#define CHECK_DTYPE(CType) if (!std::is_same<Type, CType>::value) throw std::invalid_argument("Incorrect dtype");

    SET_CTYPE_FROM_DAL_TYPE(homogen_input.get_metadata().get_data_type(0),
                            CHECK_DTYPE,
                            throw std::invalid_argument("Unknown table dtype"))

    // keep the helper macro local to this function
#undef CHECK_DTYPE

    // verify only one column
    if (homogen_input.get_column_count() != 1) {
        throw std::invalid_argument("Incorrect dimensions.");
    }

    // get length
    length = static_cast<std::int64_t>(homogen_input.get_row_count());

    // get pointer
    auto bytes_array = dal::detail::get_original_data(homogen_input);
    const bool is_mutable = bytes_array.has_mutable_data();

    // The graph setters used by the caller take non-const pointers, so constness is cast
    // away when the table only exposes immutable data.
    ptr = is_mutable ? reinterpret_cast<Type*>(bytes_array.get_mutable_data())
                     : const_cast<Type*>(reinterpret_cast<const Type*>(bytes_array.get_data()));
}

/// Builds a oneDAL undirected graph from the three arrays of a zero-indexed CSR matrix.
///
/// Because oneDAL graphs do not have the ability to call python capsule destructors,
/// graphs cannot be directly created from numpy array types. Conversion from oneDAL
/// tables makes a simple and consistent interface which matches other estimators. The
/// csr table cannot be used because of the data type of the indices and indptr, which
/// are hardcoded to int64, and because they are 1-indexed.
///
/// @tparam Float    Edge value type.
/// @param data      Single-column table of edge values (`Float`).
/// @param indices   Single-column table of column indices (`std::int32_t`).
/// @param indptr    Single-column table of row offsets (`std::int64_t`), one more entry
///                  than there are vertices.
/// @return          Graph whose topology and edge values point at the tables' storage.
/// @throws std::invalid_argument if a table fails `_table_checks`, the CSR arrays are
///         inconsistent, or a self-loop (diagonal entry) is stored.
template <typename Float>
graph_t<Float> tables_to_undirected_graph(const table& data,
                                          const table& indices,
                                          const table& indptr) {
    graph_t<Float> res;

    Float* edge_ptr = nullptr;
    std::int32_t* cols = nullptr;
    std::int64_t* rows = nullptr;
    std::int64_t data_count = 0, col_count = 0, vertex_count = 0;

    _table_checks<Float>(data, edge_ptr, data_count);
    _table_checks<std::int32_t>(indices, cols, col_count);
    _table_checks<std::int64_t>(indptr, rows, vertex_count);

    // verify data and indices are same lengths
    if (data_count != col_count) {
        throw std::invalid_argument("Got invalid csr object.");
    }
    // indptr always holds vertex_count + 1 offsets, so an empty indptr is malformed
    if (vertex_count < 1) {
        throw std::invalid_argument("Got invalid csr object.");
    }
    // -1 needed to match oneDAL graph inputs
    vertex_count--;

    // Undirected graphs in oneDAL do not check for self-loops. This will iterate through
    // the data to verify that nothing along the diagonal is stored in the csr format.
    // Every row is scanned (not just the first min(col_count, vertex_count) ones), and
    // the row offsets are range-checked before they are used to index `cols`.
    // This closely resembles scipy.sparse
    for (std::int64_t u = 0; u < vertex_count; ++u) {
        const std::int64_t row_begin = rows[u];
        const std::int64_t row_end = rows[u + 1];
        if (row_begin < 0 || row_end < row_begin || row_end > col_count) {
            throw std::invalid_argument("Got invalid csr object.");
        }
        for (std::int64_t j = row_begin; j < row_end; ++j) {
            if (cols[j] == u) {
                throw std::invalid_argument("Self-loops are not allowed.\n");
            }
        }
    }

    auto& graph_impl = dal::detail::get_impl(res);
    using vertex_set_t = typename dal::preview::graph_traits<graph_t<Float>>::vertex_set;
    dal::preview::detail::rebinded_allocator ra(graph_impl._vertex_allocator);
    // NOTE(review): only the raw `degrees` pointer is handed to the graph; confirm that
    // `degrees_array` going out of scope does not release that storage.
    auto [degrees_array, degrees] = ra.template allocate_array<vertex_set_t>(vertex_count);
    for (std::int64_t u = 0; u < vertex_count; u++) {
        degrees[u] = rows[u + 1] - rows[u];
    }

    // NOTE(review): col_count / 2 as the edge count assumes every undirected edge is stored
    // twice (symmetric matrix) - this is not verified here. The graph receives raw pointers
    // into the tables' storage, so presumably the tables must outlive it; confirm.
    graph_impl.set_topology(vertex_count, col_count / 2, rows, cols, col_count, degrees);
    graph_impl.set_edge_values(edge_ptr, col_count / 2);

    return res;
}

// Dispatch stage that turns the string stored under params["method"] into a oneDAL
// Louvain method tag and forwards the call to the wrapped `ops` functor.
template <typename Task, typename Ops>
struct method2t {
// `task` is accepted only so the Task template argument can be deduced; it is not stored.
method2t(const Task& task, const Ops& ops) : ops(ops) {}

// Float is the floating-point type chosen by the enclosing dispatcher.
template <typename Float>
auto operator()(const py::dict& params) {
using namespace preview::louvain;

// This local is deliberately named `method`: the dispatch macros below receive it by name.
// `method::fast` / `method::by_default` still resolve to the oneDAL tags because lookup of
// a name followed by `::` ignores variables.
const auto method = params["method"].cast<std::string>();

// NOTE(review): presumably each macro returns ops' result for <Float, tag> when the string
// matches, and the last one throws for any other value - confirm against the macro
// definitions in the shared onedal headers.
ONEDAL_PARAM_DISPATCH_VALUE(method, "fast", ops, Float, method::fast);
ONEDAL_PARAM_DISPATCH_VALUE(method, "by_default", ops, Float, method::by_default);
ONEDAL_PARAM_DISPATCH_THROW_INVALID_VALUE(method);
}

// Functor invoked once the method has been resolved (held by value).
Ops ops;
};

/// Functor that translates the Python parameter dictionary into a Louvain descriptor.
struct params2desc {
    /// Creates a default descriptor for <Float, Method, Task> and overrides its resolution,
    /// accuracy threshold and maximum iteration count with the values stored in `params`
    /// under "resolution", "accuracy_threshold" and "max_iteration_count".
    template <typename Float, typename Method, typename Task>
    auto operator()(const pybind11::dict& params) {
        namespace louvain = dal::preview::louvain;

        auto result = louvain::descriptor<Float, Method, Task>();

        // Each value is read and applied before the next one is touched, so a bad entry
        // surfaces in the same order as the keys are listed above.
        const double resolution = params["resolution"].cast<double>();
        result.set_resolution(resolution);

        const double tolerance = params["accuracy_threshold"].cast<double>();
        result.set_accuracy_threshold(tolerance);

        const std::int64_t iteration_limit =
            params["max_iteration_count"].cast<std::int64_t>();
        result.set_max_iteration_count(iteration_limit);

        return result;
    }
};

/// Registers the two `vertex_partitioning` overloads on module `m`.
///
/// Both overloads take the parameter dict plus the three CSR component tables (values,
/// column indices, row offsets), build a oneDAL graph from them and run the dispatch chain
/// fptype2t -> method2t -> vertex_partitioning_ops. The first overload additionally takes
/// an initial partition table that is forwarded to the algorithm input.
template <typename Task>
void init_vertex_partitioning_ops(py::module_& m) {
    m.def("vertex_partitioning",
          [](const py::dict& params,
             const table& data,
             const table& indices,
             const table& indptr,
             const table& initial_partition) {
              using namespace preview::louvain;
              using input_t = vertex_partitioning_input<graph_t<double>, Task>;
              // create graph from oneDAL tables
              // only int and double topologies are currently exported to the oneDAL shared object
              graph_t<double> graph = tables_to_undirected_graph<double>(data, indices, indptr);
              vertex_partitioning_ops ops(input_t{ graph, initial_partition }, params2desc{});
              return fptype2t{ method2t{ Task{}, ops } }(params);
          });
    m.def("vertex_partitioning",
          [](const py::dict& params,
             const table& data,
             const table& indices,
             const table& indptr) {
              using namespace preview::louvain;
              using input_t = vertex_partitioning_input<graph_t<double>, Task>;
              // create graph from oneDAL tables
              // only int and double topologies are currently exported to the oneDAL shared object
              graph_t<double> graph = tables_to_undirected_graph<double>(data, indices, indptr);
              vertex_partitioning_ops ops(input_t{ graph }, params2desc{});
              return fptype2t{ method2t{ Task{}, ops } }(params);
          });
}

/// Exposes the Louvain result type for `Task` to Python as `vertex_partitioning_result`,
/// with a default constructor and the `labels`, `modularity` and `community_count`
/// properties.
template <typename Task>
void init_vertex_partitioning_result(py::module_& m) {
    using namespace preview::louvain;
    using result_t = vertex_partitioning_result<Task>;

    // The Python-visible class name previously read "vertex_paritioning_result" (typo).
    py::class_<result_t>(m, "vertex_partitioning_result")
        .def(py::init())
        .DEF_ONEDAL_PY_PROPERTY(labels, result_t)
        .DEF_ONEDAL_PY_PROPERTY(modularity, result_t)
        .DEF_ONEDAL_PY_PROPERTY(community_count, result_t);
}

// Associates the vertex_partitioning task type with the string "vertex_partitioning".
// NOTE(review): presumably used by the shared binding helpers to name task-specific
// entities - confirm against the macro definition.
ONEDAL_PY_TYPE2STR(preview::louvain::task::vertex_partitioning, "vertex_partitioning");

// Declare instantiators for the two registration templates defined above so they can be
// applied to a task list inside the module initializer.
ONEDAL_PY_DECLARE_INSTANTIATOR(init_vertex_partitioning_ops);
ONEDAL_PY_DECLARE_INSTANTIATOR(init_vertex_partitioning_result);

// Module initializer: creates the "louvain" submodule and registers the
// vertex_partitioning functions and result class for every task in `task_list`.
// `m` is not declared here; it is presumably provided by ONEDAL_PY_INIT_MODULE.
ONEDAL_PY_INIT_MODULE(louvain) {
using namespace dal::detail;
using namespace dal::preview::louvain;

// Louvain currently has a single task.
using task_list = types<task::vertex_partitioning>;
auto sub = m.def_submodule("louvain");

ONEDAL_PY_INSTANTIATE(init_vertex_partitioning_ops, sub, task_list);
ONEDAL_PY_INSTANTIATE(init_vertex_partitioning_result, sub, task_list);
}

} // namespace oneapi::dal::python
68 changes: 68 additions & 0 deletions onedal/cluster/louvain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# ===============================================================================
# Copyright 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===============================================================================

import warnings

import numpy as np

from daal4py.sklearn._utils import get_dtype

from ..common._base import BaseEstimator
from ..common._mixin import ClusterMixin
from ..datatypes import from_table, to_table
from ..utils.validation import _check_array, _check_X_y, _is_csr


class Louvain(BaseEstimator, ClusterMixin):
    """Louvain community detection on a graph given as a CSR adjacency matrix.

    Parameters
    ----------
    resolution : float, default=1.0
        Passed to the oneDAL backend as ``resolution``.
    tol : float, default=0.0001
        Passed to the oneDAL backend as ``accuracy_threshold``.
    max_iter : int, default=10
        Passed to the oneDAL backend as ``max_iteration_count``.

    Attributes
    ----------
    labels_ : ndarray
        Flattened ``labels`` table returned by the backend.
    modularity_ : float
        ``modularity`` value returned by the backend.
    community_count_ : int
        ``community_count`` value returned by the backend.
    n_features_in_ : int
        Number of columns of the matrix seen during ``fit``.
    """

    def __init__(self, resolution=1.0, *, tol=0.0001, max_iter=10):
        self.resolution = resolution
        self.tol = tol
        self.max_iter = max_iter

    def _get_onedal_params(self, dtype=np.float64):
        """Return the parameter dictionary consumed by the oneDAL backend."""
        return {
            "fptype": dtype,
            "method": "by_default",
            "resolution": float(self.resolution),
            "accuracy_threshold": float(self.tol),
            "max_iteration_count": int(self.max_iter),
        }

    def fit(self, X, y=None, queue=None):
        """Partition the vertices of the graph described by ``X``.

        Parameters
        ----------
        X : scipy.sparse CSR matrix
            Adjacency matrix of the graph; converted to float64.
        y : array-like, default=None
            Optional initial partition; converted to int64 as required by oneDAL.
        queue : object, default=None
            Only present to match the convention of other onedal estimators; a
            warning is issued when it is not None because Louvain is CPU-only.

        Returns
        -------
        self : Louvain
            The fitted estimator.

        Raises
        ------
        TypeError
            If ``X`` is not a CSR sparse matrix.
        ValueError
            If the column indices of ``X`` cannot be represented as int32.
        """
        # queue is only included to match convention for all onedal estimators
        if queue is not None:
            warnings.warn("Louvain is implemented only for CPU")
        # explicit raise instead of assert: asserts are stripped under ``python -O``
        if not _is_csr(X):
            raise TypeError("input must be CSR sparse")
        # limitations in oneDAL's shared object force the topology to double type
        if y is None:
            X = _check_array(X, accept_sparse="csr", dtype=np.float64)
        else:
            X, y = _check_X_y(X, y, accept_sparse="csr", dtype=np.float64)
            y = y.astype(np.int64)  # restriction by oneDAL initial partition

        # column indices are always < X.shape[1], so this guards the int32 cast below
        if X.shape[1] > np.iinfo(np.int32).max:
            raise ValueError("input has too many columns for int32 indices")

        module = self._get_backend("louvain", "vertex_partitioning", None)

        # conversion of a scipy csr to dal csr_table will have incorrect dtypes and indices
        # must be done via three tables with types double, int32, int64 for oneDAL graph type.
        # scipy stores indices and indptr with one shared dtype, so both are cast
        # explicitly to what the backend reads them as.
        arrays = [
            np.ascontiguousarray(X.data, dtype=np.float64),
            np.ascontiguousarray(X.indices, dtype=np.int32),
            np.ascontiguousarray(X.indptr, dtype=np.int64),
        ]
        if y is not None:
            arrays.append(y)
        data = to_table(*arrays)

        params = self._get_onedal_params(data[0].dtype)
        result = module.vertex_partitioning(params, *data)
        self.labels_ = from_table(result.labels).ravel()
        self.modularity_ = float(result.modularity)
        self.community_count_ = int(result.community_count)
        self.n_features_in_ = X.shape[1]
        return self
Loading
Loading