欢迎光临散文网 会员登陆 & 注册

北太天元插件开发实现sortrow的一个例子(非常简单)

2023-05-17 21:17 作者:卢朓  | 我要投稿

/**

 * 北太天元的插件开发示例,这个还不是完整的,我这几天的视频的想法是

 * 每次介绍很少的几个知识点,面向的对象是对c++11之后的特性不熟悉的同学。

 * 增加了用 std::shared_ptr<void> 回收不同类型对象的内存的功能

 * 增加支持更多类型的输入和输出参数

 * 实现了一个sortrow 功能的插件函数

 */


#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <exception>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>

#include "bex/bex.hpp"
#include "get_data_from_bxArray.h"
#include "set_data_to_bxArray.h"


namespace ParseParams {

   template <class _T>

      class FunTrait;

   template <typename R, typename... Args>

      class FunTrait<R(Args...)>{

         public:

            static constexpr size_t n_args = sizeof...(Args);

            static constexpr const std::array<const std::type_info *, n_args> infos = {&typeid(Args)...};


         public:

         int required_params;

         std::vector<std::shared_ptr<void>> trash_bin;

         std::array<void *, n_args> passed_args_ptr;

         //变量类型函数句柄, 变量名是decorated_func

         R(*decorated_func)

            (Args...);



         public:

         FunTrait(R (*func)(Args...), int num_required = 0){

            decorated_func = func;

          required_params = num_required;

         }


         template <size_t... I>

         R eval_impl(std::index_sequence<I...>){

            return decorated_func((Args)passed_args_ptr[I]...);

         }


         R eval(){

            return eval_impl(std::make_index_sequence<n_args>());

         }


         int check_in_args_type(int nrhs, const bxArray * prhs[]){

            for(size_t i= 0; i< n_args; i++){

               if(infos[i]->name() == typeid(const double *).name()){

                  if(!bxIsDouble(prhs[i])){

                     bxPrintf("第%d输入参数必须是double类型",i);

                     return 1;

                  }

                  passed_args_ptr[i] = (void *)(bxGetDoubles(prhs[i]));

               }   

               else if(infos[i]->name() == typeid(const int32_t *).name()){

                  if(!bxIsInt32(prhs[i])){

                     bxPrintf("第%d输入参数必须是int32类型",i);

                     return 1;

                  }

                  passed_args_ptr[i] = (void *)(bxGetInt32s(prhs[i]));

               }   

               else if(infos[i]->name() == typeid(const int64_t *).name()){

                  if(!bxIsInt64(prhs[i])){

                     bxPrintf("第%d输入参数必须是int64类型",i);

                     return 1;

                  }

                  passed_args_ptr[i] = (void *)(bxGetInt64s(prhs[i]));

               }   

               else if(infos[i]->name() == typeid(const std::string *).name()){

                  if(!bxIsString(prhs[i])){

                     bxPrintf("第%d输入参数必须是string类型",i);

                     return 1;

                  }

                  passed_args_ptr[i] = (void *)(new std::string(bxGetStringDataPr(prhs[i])));

                  trash_bin.push_back(std::shared_ptr<std::string>((std::string *)passed_args_ptr[i]));

               }

               else if(infos[i]->name() == typeid(const std::vector<int> *).name()) {

                  if(!bxIsInt32(prhs[i])){

                     bxPrintf("第%d输入参数必须是int32 mat类型",i);

                     return 1;

                  }

                  std::vector<int> * x = new std::vector<int>();

                  get_vector_int(*x, prhs[i]);

                  passed_args_ptr[i] = (void *)x;

                  trash_bin.push_back(std::shared_ptr<std::vector<int>>(

                           (std::vector<int> *)passed_args_ptr[i]));

               }

               else if(infos[i]->name() == typeid(const std::vector<double> *).name()) {

                  if(!bxIsDouble(prhs[i])){

                     bxPrintf("第%d输入参数必须是double mat类型",i);

                     return 1;

                  }

                  std::vector<double> * x = new std::vector<double>();

                  get_vector_double(*x, prhs[i]);

                  passed_args_ptr[i] = (void *)x;

                  trash_bin.push_back(std::shared_ptr<std::vector<double>>(

                           (std::vector<double> *)passed_args_ptr[i]));

               }

               else if(infos[i]->name() == typeid(const std::vector<std::vector<double>> *).name()) {

                  if(!bxIsDouble(prhs[i])){

                     bxPrintf("第%d输入参数必须是double mat类型",i);

                     return 1;

                  }

                  std::vector<std::vector<double>> * x = new std::vector<std::vector<double>>();

                  get_matrix_double(*x, prhs[i]);

                  passed_args_ptr[i] = (void *)x;

                  trash_bin.push_back(std::shared_ptr<std::vector<std::vector<double>>>(

                           (std::vector<std::vector<double>> *)passed_args_ptr[i]));

               }

               else {

                     bxPrintf("第%d输入参数类型不对",i);

                     return 1;

               }

            }

            return 0;

         }


         void return_to_bxArray(R result, int nlhs, bxArray *plhs[]){

            if(nlhs <= 0 ) return;

            if constexpr (std::is_same<char, R>::value){

               char tmp[2] ={result, '\0'};

               plhs[0] = bxCreateString(tmp);   

            }   

            else if constexpr (std::is_same<int32_t,R>::value){

               plhs[0] = bxCreateInt32Scalar(result);

            }

            else if constexpr (std::is_same<double,R>::value){

               plhs[0] = bxCreateDoubleMatrix(1,1,bxREAL);

               double * ptr = bxGetDoubles(plhs[0]);

               *ptr = result;

            }   

            else if constexpr (std::is_same<std::string,R>::value){

               plhs[0] = bxCreateString(result.c_str());

            }

            else if constexpr (std::is_same<std::vector<double>, R>::value){

               set_vector_double(result, plhs[0]);

            }

            else if constexpr (std::is_same<std::vector<std::vector<double>>, R>::value){

               set_matrix_double(result, plhs[0]);

            }

            else if constexpr (std::is_same<std::vector<std::string>, R>::value){

               set_vector_string(result, plhs[0]);

            }

         }


      };

}


/**
 * Lexicographic "less than" for rows (std::vector<T>), restricted to the
 * 0-based columns listed in `ind` and compared in that order.
 *
 * For floating-point T, NaN is ordered after every number and NaNs are
 * equivalent to each other. With plain `<` / `>` a NaN would be "equivalent"
 * to every value, which breaks strict weak ordering and is undefined behaviour
 * for std::sort.
 *
 * Column indices are NOT bounds-checked here; callers must validate them.
 */
template<typename T>
class my_less{

   public:

   /// 0-based key columns, most significant first.
   std::vector<int> ind;

   my_less(const std::vector<int>& indices) : ind(indices) {}

   bool operator()(const std::vector<T> & s, const std::vector<T> & t) const {
      for(const int i : ind){
         const T & a = s[i];
         const T & b = t[i];
         if constexpr (std::is_floating_point<T>::value){
            const bool a_nan = std::isnan(a);
            const bool b_nan = std::isnan(b);
            if(a_nan || b_nan){
               if(a_nan && b_nan) continue;   // tie on this column
               return b_nan;                  // number < NaN
            }
         }
         if(a < b) return true;
         if(b < a) return false;
      }
      return false;
   }

};


std::vector<std::vector<double>> sortrow(

      const std::vector<std::vector<double>> * v,

   const std::vector<int> * ind   ){

   size_t m = v->size();

   size_t n = 0;

   if(m > 0)

      n = (*v)[0].size();


   std::vector<std::vector<double>> r(m);

   for(size_t i=0;i<m;i++){

      r[i].resize(n);

      for(size_t j = 0;j<n;j++){

         r[i][j] = (*v)[i][j];

      }

   }

   std::vector<int> ba_ind(*ind);

   for(auto & it : ba_ind) it -= 1;

   my_less<double> less_a(ba_ind);


   std::sort(r.begin(), r.end(), less_a);



   return r;

}


void cmd_sortrow(int nlhs, bxArray *plhs[], int nrhs, const bxArray *prhs[]) {

   ParseParams::FunTrait<decltype(sortrow)> q(sortrow,0);

   if(nrhs < q.n_args ){

      bxPrintf("输入参数%d < %d", nrhs, q.n_args);

      return;

   }   

   if(0 != q.check_in_args_type(nrhs, prhs)){

      bxPrintf("输入参数赋值出错\n");

      return;

   }

   auto result = q.eval();

   q.return_to_bxArray(result, nlhs, plhs);


}



// Functions exported by this plugin, one entry per function:
//   { name visible to the host, handler, third field (always nullptr here) }.
// NOTE(review): the meaning of the third field is defined in bex/bex.hpp — confirm.
// The trailing entry with an empty name and null handler presumably marks the
// end of the table, since bxPluginFunctions() hands out no length.
static bexfun_info_t flist[] = {
   { "sortrow", cmd_sortrow, nullptr },
   { "",        nullptr,     nullptr },
};


/**
 * Plugin export hook: gives the host the address of the first entry of
 * this plugin's function table (flist).
 */
bexfun_info_t *bxPluginFunctions() {
   bexfun_info_t * const table = &flist[0];
   return table;
}



北太天元插件开发实现sortrow的一个例子(非常简单)的评论 (共 条)

分享到微博请遵守国家法律