From c07bcda055042f4d5c7928e2ac8f8825e0320d9a Mon Sep 17 00:00:00 2001
From: ali <ali.farnudi@ens-lyon.fr>
Date: Mon, 18 Mar 2024 21:40:25 +0100
Subject: [PATCH] #3 add function that returns 1d array

---
 src/main.cpp                   | 13 +++++++++++++
 src/pybind_example/__init__.py |  2 ++
 tests/test_basic.py            |  7 ++++++-
 3 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/src/main.cpp b/src/main.cpp
index 9e3a86f..b5fac6a 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -32,6 +32,18 @@ double sum_2d_array(const std::vector<std::vector<double>> &arr) {
     return sum;
 }
 
+std::vector<double> sum_rows(const std::vector<std::vector<double>> &arr) {
+    std::vector<double> row_sums;
+    for (const auto &row : arr) {
+        double row_sum = 0;
+        for (double num : row) {
+            row_sum += num;
+        }
+        row_sums.push_back(row_sum);
+    }
+    return row_sums;
+}
+
 
 namespace py = pybind11;
 
@@ -68,6 +80,7 @@ PYBIND11_MODULE(_core, m) {
 
     m.def("sum_array", &sum_array, "Calculate the sum of elements in an array");
     m.def("sum_2d_array", &sum_2d_array, "Calculate the sum of elements in a 2D array");
+    m.def("sum_rows", &sum_rows, "Calculate the sum of elements in each row of a 2D array");
 
 #ifdef VERSION_INFO
     m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO);
diff --git a/src/pybind_example/__init__.py b/src/pybind_example/__init__.py
index bd58be8..38a0f38 100644
--- a/src/pybind_example/__init__.py
+++ b/src/pybind_example/__init__.py
@@ -8,6 +8,7 @@ from ._core import (
     multiply,
     sum_array,
     sum_2d_array,
+    sum_rows,
 )
 
 __all__ = [
@@ -18,4 +19,5 @@ __all__ = [
     "multiply",
     "sum_array",
     "sum_2d_array",
+    "sum_rows",
 ]
diff --git a/tests/test_basic.py b/tests/test_basic.py
index 55efadf..4ff6cde 100644
--- a/tests/test_basic.py
+++ b/tests/test_basic.py
@@ -32,5 +32,10 @@ def test_sum_array_numpy_array_1D_int():
 
 
 def test_sum_array_numpy_array_2D_float():
-    array = np.arange(9).reshape((3,3))
+    array = np.arange(9).reshape((3, 3))
     assert m.sum_2d_array(array) == np.sum(array)
+
+
+def test_sum_rows_numpy_array():
+    array = np.arange(9).reshape((3, 3))
+    np.testing.assert_array_almost_equal(m.sum_rows(array), np.sum(array, axis=1))
-- 
GitLab