北太天元插件开发：实现 sortrow 的一个例子（非常简单）

/**
* 北太天元的插件开发示例。这个示例还不完整，我这几天录制视频的思路是
* 每次只介绍少量几个知识点，面向对 C++11 之后的新特性还不熟悉的同学。
* 增加了用 std::shared_ptr<void> 回收不同类型对象内存的功能。
* 增加了对更多类型的输入和输出参数的支持。
* 实现了一个 sortrow 功能的插件函数。
*/
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <exception>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>

#include "bex/bex.hpp"
#include "get_data_from_bxArray.h"
#include "set_data_to_bxArray.h"
namespace ParseParams {
template <class _T>
class FunTrait;
template <typename R, typename... Args>
class FunTrait<R(Args...)>{
public:
static constexpr size_t n_args = sizeof...(Args);
static constexpr const std::array<const std::type_info *, n_args> infos = {&typeid(Args)...};
public:
int required_params;
std::vector<std::shared_ptr<void>> trash_bin;
std::array<void *, n_args> passed_args_ptr;
//变量类型函数句柄, 变量名是decorated_func
R(*decorated_func)
(Args...);
public:
FunTrait(R (*func)(Args...), int num_required = 0){
decorated_func = func;
required_params = num_required;
}
template <size_t... I>
R eval_impl(std::index_sequence<I...>){
return decorated_func((Args)passed_args_ptr[I]...);
}
R eval(){
return eval_impl(std::make_index_sequence<n_args>());
}
int check_in_args_type(int nrhs, const bxArray * prhs[]){
for(size_t i= 0; i< n_args; i++){
if(infos[i]->name() == typeid(const double *).name()){
if(!bxIsDouble(prhs[i])){
bxPrintf("第%d输入参数必须是double类型",i);
return 1;
}
passed_args_ptr[i] = (void *)(bxGetDoubles(prhs[i]));
}
else if(infos[i]->name() == typeid(const int32_t *).name()){
if(!bxIsInt32(prhs[i])){
bxPrintf("第%d输入参数必须是int32类型",i);
return 1;
}
passed_args_ptr[i] = (void *)(bxGetInt32s(prhs[i]));
}
else if(infos[i]->name() == typeid(const int64_t *).name()){
if(!bxIsInt64(prhs[i])){
bxPrintf("第%d输入参数必须是int64类型",i);
return 1;
}
passed_args_ptr[i] = (void *)(bxGetInt64s(prhs[i]));
}
else if(infos[i]->name() == typeid(const std::string *).name()){
if(!bxIsString(prhs[i])){
bxPrintf("第%d输入参数必须是string类型",i);
return 1;
}
passed_args_ptr[i] = (void *)(new std::string(bxGetStringDataPr(prhs[i])));
trash_bin.push_back(std::shared_ptr<std::string>((std::string *)passed_args_ptr[i]));
}
else if(infos[i]->name() == typeid(const std::vector<int> *).name()) {
if(!bxIsInt32(prhs[i])){
bxPrintf("第%d输入参数必须是int32 mat类型",i);
return 1;
}
std::vector<int> * x = new std::vector<int>();
get_vector_int(*x, prhs[i]);
passed_args_ptr[i] = (void *)x;
trash_bin.push_back(std::shared_ptr<std::vector<int>>(
(std::vector<int> *)passed_args_ptr[i]));
}
else if(infos[i]->name() == typeid(const std::vector<double> *).name()) {
if(!bxIsDouble(prhs[i])){
bxPrintf("第%d输入参数必须是double mat类型",i);
return 1;
}
std::vector<double> * x = new std::vector<double>();
get_vector_double(*x, prhs[i]);
passed_args_ptr[i] = (void *)x;
trash_bin.push_back(std::shared_ptr<std::vector<double>>(
(std::vector<double> *)passed_args_ptr[i]));
}
else if(infos[i]->name() == typeid(const std::vector<std::vector<double>> *).name()) {
if(!bxIsDouble(prhs[i])){
bxPrintf("第%d输入参数必须是double mat类型",i);
return 1;
}
std::vector<std::vector<double>> * x = new std::vector<std::vector<double>>();
get_matrix_double(*x, prhs[i]);
passed_args_ptr[i] = (void *)x;
trash_bin.push_back(std::shared_ptr<std::vector<std::vector<double>>>(
(std::vector<std::vector<double>> *)passed_args_ptr[i]));
}
else {
bxPrintf("第%d输入参数类型不对",i);
return 1;
}
}
return 0;
}
void return_to_bxArray(R result, int nlhs, bxArray *plhs[]){
if(nlhs <= 0 ) return;
if constexpr (std::is_same<char, R>::value){
char tmp[2] ={result, '\0'};
plhs[0] = bxCreateString(tmp);
}
else if constexpr (std::is_same<int32_t,R>::value){
plhs[0] = bxCreateInt32Scalar(result);
}
else if constexpr (std::is_same<double,R>::value){
plhs[0] = bxCreateDoubleMatrix(1,1,bxREAL);
double * ptr = bxGetDoubles(plhs[0]);
*ptr = result;
}
else if constexpr (std::is_same<std::string,R>::value){
plhs[0] = bxCreateString(result.c_str());
}
else if constexpr (std::is_same<std::vector<double>, R>::value){
set_vector_double(result, plhs[0]);
}
else if constexpr (std::is_same<std::vector<std::vector<double>>, R>::value){
set_matrix_double(result, plhs[0]);
}
else if constexpr (std::is_same<std::vector<std::string>, R>::value){
set_vector_string(result, plhs[0]);
}
}
};
}
/**
 * Lexicographic "less than" for table rows that looks only at the columns
 * listed in `ind` (0-based), in that order; the first listed column whose
 * values differ decides the result.
 *
 * For floating-point T, NaN is ordered after every number and all NaNs are
 * equivalent. Plain `<` / `>` would make NaN neither smaller nor larger than
 * anything, which violates the strict weak ordering std::sort requires.
 */
template<typename T>
class my_less{
public:
    /// 0-based column indices to compare, highest priority first.
    std::vector<int> ind;

    my_less(const std::vector<int>& _ind) : ind(_ind) {}

    bool operator()(const std::vector<T> & s, const std::vector<T> & t) const {
        for (auto i : ind) {
            const T& a = s[i];
            const T& b = t[i];
            if constexpr (std::is_floating_point_v<T>) {
                const bool a_nan = std::isnan(a);
                const bool b_nan = std::isnan(b);
                if (a_nan || b_nan) {
                    if (a_nan && b_nan) {
                        continue;       // equivalent in this column: try the next one
                    }
                    return b_nan;       // a number sorts before NaN
                }
            }
            if (a < b) {
                return true;
            }
            if (a > b) {
                return false;
            }
        }
        return false;
    }
};
/**
 * Sorts the rows of a table by the given columns (in the spirit of MATLAB sortrows).
 *
 * Rows are compared on column ind[0] first, ties are broken by ind[1], and so
 * on. Column indices are 1-based, as written by script users. The sort is
 * stable: rows equal on every listed column keep their input order.
 *
 * @param v    rows to sort; not modified (a sorted copy is returned).
 * @param ind  1-based column indices, highest priority first.
 * @return     the sorted copy of *v (0 or 1 rows are returned unchanged).
 * @throws std::invalid_argument if v or ind is null.
 * @throws std::out_of_range if, with two or more rows, a column index is
 *         < 1 or larger than the length of the shortest row.
 */
std::vector<std::vector<double>> sortrow(
    const std::vector<std::vector<double>> * v,
    const std::vector<int> * ind ){
    if (v == nullptr || ind == nullptr) {
        throw std::invalid_argument("sortrow: 输入参数为空指针");
    }
    // Copy the rows as-is (also safe when rows have different lengths).
    std::vector<std::vector<double>> r(*v);
    // Nothing to compare: the column list is never consulted.
    if (r.size() < 2) {
        return r;
    }
    // Every key column must exist in every row, so bound by the shortest row.
    size_t min_cols = r.front().size();
    for (const auto& row : r) {
        min_cols = std::min(min_cols, row.size());
    }
    // Convert to 0-based indices, rejecting anything that would index past a row.
    std::vector<int> keys;
    keys.reserve(ind->size());
    for (const int c : *ind) {
        if (c < 1 || static_cast<size_t>(c) > min_cols) {
            throw std::out_of_range("sortrow: 列下标 " + std::to_string(c)
                                    + " 超出范围 [1, " + std::to_string(min_cols) + "]");
        }
        keys.push_back(c - 1);
    }
    // stable_sort: ties keep their original order (std::sort gives no such guarantee).
    std::stable_sort(r.begin(), r.end(), my_less<double>(keys));
    return r;
}
/**
 * bex entry point for `sortrow(A, ind)`: converts the inputs through
 * ParseParams::FunTrait, calls sortrow() and writes the result to plhs[0].
 * Every failure is reported with bxPrintf, and the function then returns
 * without setting an output.
 */
void cmd_sortrow(int nlhs, bxArray *plhs[], int nrhs, const bxArray *prhs[]) {
    ParseParams::FunTrait<decltype(sortrow)> q(sortrow, 0);
    // n_args is size_t: compare as int (no signed/unsigned mix) and pass an int to %d.
    const int n_required = static_cast<int>(q.n_args);
    if (nrhs < n_required) {
        bxPrintf("输入参数%d < %d", nrhs, n_required);
        return;
    }
    // Keep C++ exceptions (allocation failures, errors thrown while converting
    // or sorting) from propagating into the host through this callback.
    try {
        if (0 != q.check_in_args_type(nrhs, prhs)) {
            bxPrintf("输入参数赋值出错\n");
            return;
        }
        auto result = q.eval();
        // result is not used afterwards: move it instead of copying the matrix.
        q.return_to_bxArray(std::move(result), nlhs, plhs);
    } catch (const std::exception& e) {
        bxPrintf("sortrow: %s\n", e.what());
    } catch (...) {
        bxPrintf("sortrow: 未知错误\n");
    }
}
// Function table handed to the bex host. Each entry pairs the script-visible
// name with its implementation; the third field is left as nullptr here.
// The trailing entry (empty name, null pointers) presumably terminates the table.
static bexfun_info_t flist[] = {
    {"sortrow", cmd_sortrow, nullptr},
    {"", nullptr, nullptr},
};

// Plugin export: gives the host the table above.
// NOTE(review): the host presumably looks this symbol up by name — confirm that
// bex/bex.hpp declares it with C linkage.
bexfun_info_t *bxPluginFunctions() {
    bexfun_info_t *table = &flist[0];
    return table;
}