// mex_train_ckn_cudnn.cpp
// (Header reconstructed from repository-viewer residue; last committed by Julien Mairal.)
#include <linalg.h>
#include <mexutils.h>
#include "common_cudnn.h"

/// Builds the array of CKN layers described by the MATLAB struct `pr_model`.
///
/// `pr_model.layer` must be a cell array holding one struct per layer. For
/// each layer the scalar fields "npatch", "nfilters", "subsampling",
/// "type_layer", "type_kernel" and "sigma" are read with getScalarStruct;
/// "stride" (default 1), "zero_padding" (default false) and "pooling_mode"
/// (default POOL_GAUSSIAN_FILTER) are optional. The fields "W2", "W" and "b"
/// are bound to layers[ii].W2, .W and .b through getMatrix/getVector --
/// presumably as views on the mxArray data (TODO confirm in mexutils.h), in
/// which case `pr_model` must outlive `layers`.
///
/// @param pr_model  MATLAB struct holding the model.
/// @param layers    [out] new[]-allocated array of `nlayers` layers; the caller
///                  releases it with delete[].
/// @param nlayers   [out] number of cells in `pr_model.layer`.
///
/// Raises a MATLAB error (mexErrMsgTxt) when "layer" is missing or is not a
/// cell array, when a cell is empty or not a struct, or when a layer lacks one
/// of "W2", "W", "b". These structural checks run before allocation, so they
/// cannot leak `layers`.
template <typename T>
inline void getModel(mxArray* pr_model,Layer<T>*& layers, int& nlayers) {
   mxArray *pr_layers = mxGetField(pr_model,0,"layer");
   // mxGetField returns NULL when the field is absent.
   if (!pr_layers || !mxIsCell(pr_layers))
      mexErrMsgTxt("model.layer must be a cell array");
   // Count every cell, whatever the number of dimensions of the cell array.
   nlayers=static_cast<int>(mxGetNumberOfElements(pr_layers));

   // Validate the structure of every layer before allocating anything.
   for (int ii=0; ii<nlayers; ++ii) {
      const mxArray* layer=mxGetCell(pr_layers,ii);
      // mxGetCell returns NULL for an unset cell.
      if (!layer || !mxIsStruct(layer))
         mexErrMsgTxt("each cell of model.layer must be a struct");
      if (!mxGetField(layer,0,"W2") || !mxGetField(layer,0,"W") ||
          !mxGetField(layer,0,"b"))
         mexErrMsgTxt("each layer of the model must have fields W2, W and b");
   }

   layers = new Layer<T>[nlayers];
   for (int ii=0; ii<nlayers; ++ii) {
      mxArray* layer=mxGetCell(pr_layers,ii);
      layers[ii].num_layer=ii+1;   // 1-based layer index
      layers[ii].npatch=getScalarStruct<int>(layer,"npatch");
      layers[ii].nfilters=getScalarStruct<int>(layer,"nfilters");
      layers[ii].subsampling=getScalarStruct<int>(layer,"subsampling");
      layers[ii].stride=getScalarStructDef<int>(layer,"stride",1);
      layers[ii].zero_padding=getScalarStructDef<bool>(layer,"zero_padding",false);
      layers[ii].type_layer=getScalarStruct<int>(layer,"type_layer");
      layers[ii].type_kernel=getScalarStruct<int>(layer,"type_kernel");
      layers[ii].sigma=getScalarStruct<T>(layer,"sigma");
      layers[ii].pooling_mode=getScalarStructDef<pooling_mode_t>(layer,"pooling_mode",POOL_GAUSSIAN_FILTER);
      getMatrix(mxGetField(layer,0,"W2"),layers[ii].W2);
      getMatrix(mxGetField(layer,0,"W"),layers[ii].W);
      getVector(mxGetField(layer,0,"b"),layers[ii].b);
   }
}

template <typename Tin, typename T>
inline void callFunctionAux(mxArray* plhs[], const mxArray*prhs[], const int nlhs) {
   Map<Tin> X;
   getMap(prhs[0],X);
   Matrix<T> Y;
   getMatrix(prhs[1],Y);
   Map<Tin> Xval;
   getMap(prhs[2],Xval);
   Matrix<T> Yval;
   getMatrix(prhs[3],Yval);
   Layer<T>* layers;
   plhs[0]=mxDuplicateArray(prhs[4]);
   int nlayers;
   getModel(plhs[0],layers,nlayers);
   Matrix<T> W;
   plhs[1]=mxDuplicateArray(prhs[5]);
   getMatrix(plhs[1],W);
   Vector<T> b;
   plhs[2]=mxDuplicateArray(prhs[6]);
   getVector(plhs[2],b);
   ParamSGD<T> param;
   int threads = getScalarStructDef<int>(prhs[7],"threads",-1);
   const int device = getScalarStruct<int>(prhs[7],"device");
   param.lambda = getScalarStruct<T>(prhs[7],"lambda");
   param.lambda2 = getScalarStruct<T>(prhs[7],"lambda2");
   param.loss = getScalarStruct<loss_t>(prhs[7],"loss");
   param.epochs = getScalarStruct<int>(prhs[7],"epochs");
   param.batch_size = getScalarStruct<int>(prhs[7],"batch_size");
   param.momentum = getScalarStruct<T>(prhs[7],"momentum");
   param.eta = getScalarStruct<T>(prhs[7],"eta");
   param.scal_intercept=getScalarStruct<T>(prhs[7],"scal_intercept");
   param.update_Wb=getScalarStructDef<bool>(prhs[7],"update_Wb",true);
   param.update_model=getScalarStructDef<bool>(prhs[7],"update_model",true);
   param.update_miso=getScalarStructDef<bool>(prhs[7],"update_miso",false);
   param.preconditioning_model=getScalarStructDef<bool>(prhs[7],"preconditioning_model",false);
   param.learning_rate_mode=getScalarStructDef<int>(prhs[7],"learning_rate_mode",0);
   param.data_augmentation=getScalarStructDef<int>(prhs[7],"data_augmentation",0);
   param.it_eval=getScalarStructDef<int>(prhs[7],"it_eval",1);
   param.it_decrease=getScalarStructDef<int>(prhs[7],"it_decrease",10);
   param.active_set=getScalarStructDef<bool>(prhs[7],"active_set",false);
   init_cuda(device,true,true);
   Matrix<T> logs;
   if (nlhs==4) {
      plhs[3]=createMatrix<T>(3,param.epochs);
      getMatrix(plhs[3],logs);
   }
   sgd_solver_supervised(X,Y,Xval,Yval,layers,nlayers,W,b,param,logs);
   destroy_cuda(true,true);
   delete[](layers);
}
/// Selects the computation precision for input element type `Tin`.
///
/// Inputs  (prhs): X Y Xval Yval model W b param
/// Outputs (plhs): model W b [logs]
///
/// Only single precision is implemented. When param.double_precision is true,
/// a MATLAB error is raised instead of returning with the outputs unassigned.
template <typename Tin>
inline void callFunction(mxArray* plhs[], const mxArray*prhs[], const int nlhs) {
   const bool double_precision = getScalarStructDef<bool>(prhs[7],"double_precision",false);
   if (double_precision) {
      // TODO: dispatch to callFunctionAux<Tin,double> once supported.
      mexErrMsgTxt("double_precision=true is not supported yet");
   } else {
      callFunctionAux<Tin,float>(plhs,prhs,nlhs);
   }
}

void mexFunction(int nlhs, mxArray *plhs[],int nrhs, const mxArray *prhs[]) {
   if (nrhs != 8)
      mexErrMsgTxt("Bad number of inputs arguments");

   if (nlhs != 3 && nlhs != 4)
      mexErrMsgTxt("Bad number of output arguments");

   if (mxGetClassID(prhs[0]) == mxDOUBLE_CLASS) {
      callFunction<double>(plhs,prhs,nlhs);
   } else if (mxGetClassID(prhs[0]) == mxUINT8_CLASS) {
      callFunction<unsigned char>(plhs,prhs,nlhs);
   } else {
      callFunction<float>(plhs,prhs,nlhs);
   }
}