00001 #ifndef _DLSVMCLASSIFIER_H_ 00002 #define _DLSVMCLASSIFIER_H_ 00003 00004 #include "svm.h" 00005 #include <string> 00006 #include <vector> 00007 #include <iostream> 00008 #include <algorithm> 00009 #include "DLException.h" 00010 00011 using namespace std; 00012 00033 class DLSVMClassifier 00034 { 00035 public: 00036 00040 DLSVMClassifier(); 00041 00049 DLSVMClassifier(string model_filename, bool bTesting = true); 00050 00054 ~DLSVMClassifier(); 00055 00064 int batchTraining(vector< vector<double> > data, vector<int> dataclass); 00065 00074 vector< vector<double> > batchTesting(vector< vector<double> > data, vector<int> dataclass); 00075 00081 int test(vector<double> testdata); 00082 00088 int saveConfusionMatrix(string strFileName); 00089 00090 private: 00091 00093 struct svm_parameter param; 00094 00096 struct svm_problem prob; 00097 struct svm_model *model; 00098 struct svm_node *x_space; 00099 int cross_validation; 00100 int nr_fold; 00101 const char *error_msg; 00102 00103 string model_file_name; 00104 00106 vector< vector<double> > confusionMatrix; 00107 }; 00108 00109 #endif