Skip to content

Commit 8cd24f3

Browse files
authored
svm model implementation (#43)
* svm model implementation * add classification report to svm tuning * hotfix
1 parent 41c68fc commit 8cd24f3

3 files changed

Lines changed: 1359 additions & 0 deletions

File tree

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"id": "25921d5f-e3b7-4960-ad21-95cad70e6e53",
7+
"metadata": {
8+
"tags": []
9+
},
10+
"outputs": [],
11+
"source": [
12+
"import numpy as np\n",
13+
"import pandas as pd\n",
14+
"from sklearn.linear_model import SGDOneClassSVM\n",
15+
"from sklearn.metrics import (\n",
16+
" classification_report,\n",
17+
" precision_score,\n",
18+
" recall_score,\n",
19+
" f1_score,\n",
20+
" accuracy_score,\n",
21+
" roc_auc_score,\n",
22+
" fbeta_score,\n",
23+
")\n",
24+
"from sklearn.kernel_approximation import RBFSampler\n",
25+
"from itertools import product\n",
26+
"import csv"
27+
]
28+
},
29+
{
30+
"cell_type": "code",
31+
"execution_count": 3,
32+
"id": "1a264aca-9ea2-4ff8-b1fc-ca619e83bb59",
33+
"metadata": {
34+
"tags": []
35+
},
36+
"outputs": [],
37+
"source": [
38+
"X_train = pd.read_parquet(\"data\").to_numpy(dtype=float)\n",
39+
"validation_data = pd.read_parquet(\"benchmark\")"
40+
]
41+
},
42+
{
43+
"cell_type": "code",
44+
"execution_count": 4,
45+
"id": "0caae7d5-143d-4ca1-a06f-b766af2a9c7d",
46+
"metadata": {
47+
"tags": []
48+
},
49+
"outputs": [],
50+
"source": [
51+
"X_val = validation_data.drop(columns=['label']).to_numpy(dtype=float)\n",
52+
"y_val = validation_data['label'].values"
53+
]
54+
},
55+
{
56+
"cell_type": "code",
57+
"execution_count": 5,
58+
"id": "9fc0412f-0c32-4b8c-a90b-52b4c1409527",
59+
"metadata": {
60+
"tags": []
61+
},
62+
"outputs": [],
63+
"source": [
64+
"y_val = np.where(validation_data[\"label\"].values, -1, 1)  # truthy label -> -1 (anomaly), else 1 (normal); built from the raw column so re-running this cell cannot flip already-encoded labels"
65+
]
66+
},
67+
{
68+
"cell_type": "code",
69+
"execution_count": 8,
70+
"id": "4dbf3113-58cb-4c28-ac90-418025a6f84e",
71+
"metadata": {
72+
"tags": []
73+
},
74+
"outputs": [],
75+
"source": [
76+
"gamma_values = np.arange(0.1, 10.0, 0.1).tolist()\n",
77+
"\n",
78+
"param_values = {\"gamma\": [0.1, 1.0, 5.0, 10.0], \"nu\": [0.01, 0.05, 0.1, 0.2], \"tol\": [1e-7], \"eta0\": [1e-6]}\n",
79+
"\n",
80+
"keys = param_values.keys()\n",
81+
"values = param_values.values()\n",
82+
"combinations = [dict(zip(keys, combo)) for combo in product(*values)]"
83+
]
84+
},
85+
{
86+
"cell_type": "code",
87+
"execution_count": 9,
88+
"id": "d852cff0-d44f-4b52-8803-e109cb22af58",
89+
"metadata": {
90+
"tags": []
91+
},
92+
"outputs": [
93+
{
94+
"name": "stdout",
95+
"output_type": "stream",
96+
"text": [
97+
"1\n",
98+
"2\n",
99+
"3\n",
100+
"4\n",
101+
"5\n",
102+
"6\n",
103+
"7\n",
104+
"8\n",
105+
"9\n",
106+
"10\n",
107+
"11\n",
108+
"12\n",
109+
"13\n",
110+
"14\n",
111+
"15\n",
112+
"16\n"
113+
]
114+
}
115+
],
116+
"source": [
117+
"csv_file = \"svm_results_tuning.csv\"\n",
118+
"\n",
119+
"models = []\n",
120+
"\n",
121+
"with open(csv_file, mode=\"w\", newline=\"\") as file:\n",
122+
" writer = csv.writer(file)\n",
123+
" writer.writerow([\"gamma\", \"nu\", \"tol\", \"eta0\", \"Precision\", \"Recall Normal\", \"Recall Anomaly\", \"F1-Score\", \"Accuracy\", \"AUC\", \"F2-Score\"])\n",
124+
"\n",
125+
" for i, combo in enumerate(combinations, start=1):\n",
126+
" nystroem = RBFSampler(\n",
127+
" gamma=combo[\"gamma\"], \n",
128+
" n_components=1000,\n",
129+
" random_state=42\n",
130+
" )\n",
131+
"\n",
132+
" sgd_ocsvm = SGDOneClassSVM(\n",
133+
" nu=combo[\"nu\"],\n",
134+
" shuffle=True,\n",
135+
" learning_rate = 'constant',\n",
136+
" tol=combo[\"tol\"],\n",
137+
" random_state=42,\n",
138+
" eta0=combo[\"eta0\"],\n",
139+
" max_iter=10000\n",
140+
" )\n",
141+
"\n",
142+
" X_batch_transformed = nystroem.fit_transform(X_val)\n",
143+
"\n",
144+
" sgd_ocsvm.fit(X_batch_transformed)\n",
145+
"\n",
146+
" X_val_transformed = nystroem.transform(X_val)\n",
147+
" y_pred = sgd_ocsvm.predict(X_val_transformed)\n",
148+
"\n",
149+
" accuracy = accuracy_score(y_val, y_pred)\n",
150+
" precision = precision_score(y_val, y_pred, pos_label=1)\n",
151+
" recall_normal = recall_score(y_val, y_pred, pos_label=1)\n",
152+
" recall_anomaly = recall_score(y_val, y_pred, pos_label=-1)\n",
153+
" f1 = f1_score(y_val, y_pred, pos_label=1)\n",
154+
" auc = roc_auc_score(y_val, y_pred)\n",
155+
" f2 = fbeta_score(y_val, y_pred, beta=2, pos_label=1)\n",
156+
"\n",
157+
" print(i)\n",
158+
" models.append({\"auc\": auc, \"y_pred\": y_pred})\n",
159+
"\n",
160+
" writer.writerow([\n",
161+
" combo[\"gamma\"], combo[\"nu\"], combo[\"tol\"], combo[\"eta0\"],\n",
162+
" precision, recall_normal, recall_anomaly, f1, accuracy, auc, f2\n",
163+
" ])"
164+
]
165+
},
166+
{
167+
"cell_type": "code",
168+
"execution_count": 10,
169+
"id": "5b454121",
170+
"metadata": {},
171+
"outputs": [
172+
{
173+
"name": "stdout",
174+
"output_type": "stream",
175+
"text": [
176+
" precision recall f1-score support\n",
177+
"\n",
178+
" Anomalous 0.22 0.29 0.25 14034\n",
179+
" Normal 0.79 0.72 0.75 50702\n",
180+
"\n",
181+
" accuracy 0.62 64736\n",
182+
" macro avg 0.50 0.50 0.50 64736\n",
183+
"weighted avg 0.66 0.62 0.64 64736\n",
184+
"\n"
185+
]
186+
}
187+
],
188+
"source": [
189+
"highest_auc_item = max(models, key=lambda x: x[\"auc\"])\n",
190+
"print(classification_report(y_val, y_pred, target_names=[\"Anomalous\", \"Normal\"]))"
191+
]
192+
}
193+
],
194+
"metadata": {
195+
"kernelspec": {
196+
"display_name": "venv",
197+
"language": "python",
198+
"name": "python3"
199+
},
200+
"language_info": {
201+
"codemirror_mode": {
202+
"name": "ipython",
203+
"version": 3
204+
},
205+
"file_extension": ".py",
206+
"mimetype": "text/x-python",
207+
"name": "python",
208+
"nbconvert_exporter": "python",
209+
"pygments_lexer": "ipython3",
210+
"version": "3.11.2"
211+
}
212+
},
213+
"nbformat": 4,
214+
"nbformat_minor": 5
215+
}

0 commit comments

Comments
 (0)