As an exercise in using Optuna, I quickly wrote a text classifier that combines Doc2Vec document vectors with a support vector classifier (SVC).
The script reports the model's accuracy and weighted F1 score, and plots a confusion matrix with matplotlib's pyplot.
Prepare the label data as a CSV file with an `id` column (the document's file name) and a `labels` column (its class label), for example:
| DOCUMENT_FILE_NAME(id) | LABEL(labels) |
|---|---|
| foo.txt | bar |
| bar.txt | foo |
For command-line usage, run `python document_SVClassifier.py -h`.
Each text file is treated as a single document and classified as a whole.
Source (GitHub)
# Author: Atsuya Kobayashi @atsuya_kobayashi
# 2019/02/15 17:20

"""Support Vector Document Classifier with doc2vec & Optuna
- .csv label file must be in the form of following style
|DOCUMENT_FILE_NAME(id)|LABEL(labels)|
|----------------------|-------------|
| foo.txt | bar |
| bar.txt | foo |

Each file listed in the ``id`` column is read from the target directory,
tokenized (whitespace split, or MeCab ``-Owakati`` with ``-M``), and turned
into a vector with a pretrained Doc2Vec model. Optuna then tunes an SVC's
``C`` and ``kernel`` by 3-fold cross-validation on the training half. The
best model is evaluated on the held-out half (accuracy, weighted F1, and a
confusion-matrix plot).
"""

import argparse
import itertools
import os

import matplotlib.pyplot as plt
import numpy as np
import optuna
import pandas as pd
from gensim.models import Doc2Vec
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.svm import SVC
from tqdm import tqdm

# parameters (defaults; overwritten from the command line in __main__)
PATH_TO_CSVFILE = ""
TEXTFILE_TARGET_DIR = "/"
PATH_TO_PRETRAINED_DOC2VEC_MODEL = ""
N_OPTIMIZE_TRIAL = 20
USE_MORPH_TOKENIZER = False


def plot_confusion_matrix(cm, classes, normalize=False,
                          title='Confusion matrix', cmap=plt.cm.Blues):
    """Print a confusion matrix and draw it on the current pyplot figure.

    Args:
        cm: Square confusion matrix as returned by
            ``sklearn.metrics.confusion_matrix`` (rows = true label,
            columns = predicted label).
        classes: Tick labels, in the same order as the rows/columns of ``cm``.
        normalize: If True, divide each row by its sum so cells show the
            fraction of each true class. Rows with no samples stay at 0
            instead of becoming NaN.
        title: Figure title.
        cmap: Matplotlib colormap for the cells.
    """
    cm = np.asarray(cm)
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Guard against 0/0 for classes that never occur in the true labels.
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    # White text on dark cells, black text on light cells.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()


# for Optuna
def obj(trial):
    """Optuna objective: 3-fold cross-validation error of an SVC.

    Reads the module-level ``X_train`` / ``y_train`` assigned in the
    ``__main__`` section. Only the training split is used here; the test
    split is reserved for the final evaluation.

    Returns:
        ``1.0 - mean CV accuracy`` (the study minimizes by default).
    """
    # C is searched on a log scale over [1, 100].
    svc_c = trial.suggest_float('C', 1e0, 1e2, log=True)
    kernel = trial.suggest_categorical('kernel', ['linear', 'poly', 'rbf'])
    clf = SVC(C=svc_c, kernel=kernel)
    # cross_val_score clones and fits the estimator per fold itself, so no
    # separate fit/predict on the data is needed here.
    score = cross_val_score(clf, X_train, y_train, n_jobs=-1, cv=3)
    accuracy = score.mean()
    return 1.0 - accuracy


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Train a Support Vector Sentence Classifier')
    parser.add_argument('csv', help='PATH TO CSVFILE')
    parser.add_argument('dir', help='TEXTFILE TARGET DIRECTORY')
    parser.add_argument('model', help='PATH TO PRETRAINED DOC2VEC MODEL FILE')
    parser.add_argument("-N", "--n_trial", dest='n', default=N_OPTIMIZE_TRIAL,
                        type=int,
                        help='N OF OPTIMIZE TRIALS (Default is 20times)')
    parser.add_argument("-M", "--mecab", dest='mecab', action='store_true',
                        help='USE MECAB Owakati TAGGER')
    args = parser.parse_args()
    PATH_TO_CSVFILE = args.csv
    TEXTFILE_TARGET_DIR = args.dir
    PATH_TO_PRETRAINED_DOC2VEC_MODEL = args.model
    N_OPTIMIZE_TRIAL = args.n
    USE_MORPH_TOKENIZER = args.mecab

    if USE_MORPH_TOKENIZER:
        # Imported lazily so MeCab is only required when -M/--mecab is given.
        import MeCab
        tagger = MeCab.Tagger("-Owakati")

    df = pd.read_csv(PATH_TO_CSVFILE)

    documents = []
    for fname in tqdm(df.id, desc="Reading Files"):
        # os.path.join works whether or not the directory ends with a slash.
        with open(os.path.join(TEXTFILE_TARGET_DIR, fname)) as f:
            text = f.read()
        if USE_MORPH_TOKENIZER:
            # -Owakati emits space-separated tokens.
            doc = tagger.parse(text).strip().split()
        else:
            doc = text.strip().split()
        documents.append(doc)

    model = Doc2Vec.load(PATH_TO_PRETRAINED_DOC2VEC_MODEL)
    document_vectors = [model.infer_vector(s) for s in tqdm(documents)]

    X_train, X_test, y_train, y_test = train_test_split(
        document_vectors, df.labels, test_size=0.5, random_state=42)

    study = optuna.create_study()
    study.optimize(obj, n_trials=N_OPTIMIZE_TRIAL)
    # fits a model with best params
    clf = SVC(C=study.best_params["C"], kernel=study.best_params["kernel"])
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    # Compute confusion matrix over the sorted set of labels seen in either
    # the true or predicted test labels, and pass that same list to the plot
    # so the tick labels line up with the matrix rows/columns.
    class_labels = np.unique(np.concatenate([np.asarray(y_test),
                                             np.asarray(y_pred)]))
    cnf_matrix = confusion_matrix(y_test, y_pred, labels=class_labels)
    np.set_printoptions(precision=2)
    # Plot non-normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cnf_matrix, classes=class_labels,
                          title='Confusion matrix, without normalization')
    plt.show()
    # print result
    print(f"Acc = {accuracy_score(y_test, y_pred)}")
    print(f"F1 = {f1_score(y_test, y_pred, average='weighted')}")