From 902b4cd35af1904bf5873e274669a2cb937d325d Mon Sep 17 00:00:00 2001
From: GD <gd.dev@libertymail.net>
Date: Wed, 9 Mar 2022 11:48:33 +0100
Subject: [PATCH] fix image file usage after modifying the test dir
 installation

---
 tests/test_dictLearn.py | 38 +++++++++++++++++++++++++++++++++-----
 1 file changed, 33 insertions(+), 5 deletions(-)

diff --git a/tests/test_dictLearn.py b/tests/test_dictLearn.py
index 88714a0..c6ea44a 100644
--- a/tests/test_dictLearn.py
+++ b/tests/test_dictLearn.py
@@ -1,6 +1,7 @@
 from __future__ import absolute_import, division, print_function
 
 import sys
+import os
 import numpy as np
 import scipy
 import scipy.sparse as ssp
@@ -15,6 +16,33 @@ if not ('rand' in ssp.__dict__):
     ssprand = myscipy_rand.rand
 else:
     ssprand = ssp.rand
+    
+    
+def get_img_file_path(img):
+    """Return path to an image file
+    
+    Arguments:
+        img (string): image filename without path among 'boat.png' or 
+        'lena.png'.
+        
+    Output:
+        img_file (string): normalized path to image input filename.
+    
+    """
+    # check input
+    if not img in ["boat.png", "lena.png"]:
+        raise ValueError("bad input, `img` should be 'boat.png' or 'lena.png'")
+    # try local file
+    img_file = os.path.join("data", img)
+    if os.path.isfile(img_file):
+        img_file = os.path.abspath(img_file)
+    else:
+        # file from install
+        img_file = os.path.join(
+            os.path.dirname(os.path.abspath(spams.__file__)), img_file
+        )
+    # output
+    return img_file
 
 
 def _extract_lasso_param(f_param):
@@ -46,8 +74,8 @@ def _objective(X, D, param, imgname=None):
 
 
 def test_trainDL():
-    img_file = 'data/boat.png'
     try:
+        img_file = get_img_file_path("boat.png")
         img = Image.open(img_file)
     except:
         print("Cannot load image %s : skipping test" % img_file)
@@ -140,8 +168,8 @@ def test_trainDL():
 
 
 def test_trainDL_Memory():
-    img_file = 'data/lena.png'
     try:
+        img_file = get_img_file_path("lena.png")
         img = Image.open(img_file)
     except:
         print("Cannot load image %s : skipping test" % img_file)
@@ -202,8 +230,8 @@ def test_trainDL_Memory():
 
 
 def test_structTrainDL():
-    img_file = 'data/lena.png'
     try:
+        img_file = get_img_file_path("lena.png")
         img = Image.open(img_file)
     except Exception as e:
         print("Cannot load image %s (%s) : skipping test" % (img_file, e))
@@ -376,8 +404,8 @@ def test_structTrainDL():
 
 
 def test_nmf():
-    img_file = 'data/boat.png'
     try:
+        img_file = get_img_file_path("boat.png")
         img = Image.open(img_file)
     except:
         print("Cannot load image %s : skipping test" % img_file)
@@ -413,8 +441,8 @@ def test_nmf():
 
 # Archetypal Analysis, run first steps with FISTA and run last steps with activeSet,
 def test_archetypalAnalysis():
-    img_file = 'data/lena.png'
     try:
+        img_file = get_img_file_path("lena.png")
         img = Image.open(img_file)
     except Exception as e:
         print("Cannot load image %s (%s) : skipping test" % (img_file, e))
-- 
GitLab