From c07bcda055042f4d5c7928e2ac8f8825e0320d9a Mon Sep 17 00:00:00 2001 From: ali <ali.farnudi@ens-lyon.fr> Date: Mon, 18 Mar 2024 21:40:25 +0100 Subject: [PATCH] #3 add function that returns 1d array --- src/main.cpp | 13 +++++++++++++ src/pybind_example/__init__.py | 2 ++ tests/test_basic.py | 7 ++++++- 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/main.cpp b/src/main.cpp index 9e3a86f..b5fac6a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -32,6 +32,18 @@ double sum_2d_array(const std::vector<std::vector<double>> &arr) { return sum; } +std::vector<double> sum_rows(const std::vector<std::vector<double>> &arr) { + std::vector<double> row_sums; + for (const auto &row : arr) { + double row_sum = 0; + for (double num : row) { + row_sum += num; + } + row_sums.push_back(row_sum); + } + return row_sums; +} + namespace py = pybind11; @@ -68,6 +80,7 @@ PYBIND11_MODULE(_core, m) { m.def("sum_array", &sum_array, "Calculate the sum of elements in an array"); m.def("sum_2d_array", &sum_2d_array, "Calculate the sum of elements in a 2D array"); + m.def("sum_rows", &sum_rows, "Calculate the sum of elements in each row of a 2D array"); #ifdef VERSION_INFO m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO); diff --git a/src/pybind_example/__init__.py b/src/pybind_example/__init__.py index bd58be8..38a0f38 100644 --- a/src/pybind_example/__init__.py +++ b/src/pybind_example/__init__.py @@ -8,6 +8,7 @@ from ._core import ( multiply, sum_array, sum_2d_array, + sum_rows, ) __all__ = [ @@ -18,4 +19,5 @@ __all__ = [ "multiply", "sum_array", "sum_2d_array", + "sum_rows", ] diff --git a/tests/test_basic.py b/tests/test_basic.py index 55efadf..4ff6cde 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -32,5 +32,10 @@ def test_sum_array_numpy_array_1D_int(): def test_sum_array_numpy_array_2D_float(): - array = np.arange(9).reshape((3,3)) + array = np.arange(9).reshape((3, 3)) assert m.sum_2d_array(array) == np.sum(array) + + +def test_sum_rows_numpy_array(): + array = np.arange(9).reshape((3, 3)) + np.testing.assert_array_almost_equal(m.sum_rows(array), np.sum(array, axis=1)) -- GitLab