Machine Learning mit Smile
Di 20 September 2016 by Oliver PaetzelTags: Java Smile Machine Learning wine
Wir beginnen mit einem leeren Java-Projekt. Für die Klassifikation werden wir eine Support Vector Machine (SVM) aus dem Smile framework benutzen.
Dafür laden wir zunächst die smile-core
Komponenten herunter. Dies kann man ganz einfach über maven machen:
<dependencies>
<dependency>
<groupId>com.github.haifengl</groupId>
<artifactId>smile-core</artifactId>
<version>1.2.0</version>
</dependency>
</dependencies>
Wenn man maven nicht mag, kann man sich die jars auch einzeln besorgen. Wir werden folgende jars benötigen:
Trainingsdaten
Als Trainingsdaten nehmen wir das wine-Datenset aus dem UCI Machine Learning Repository: Wine-Daten. Dieses Datenset enthält Informationen von Weinen, die alle in der gleichen Region in Italien gewachsen sind, jedoch von drei verschiedenen Rebsorten abstammen. Die Daten liegen im CSV-Format vor und enthalten als ersten Wert die Klasse (sozusagen die Rebsorte) des jeweiligen Beispiels. Die restlichen 13 Werte sind Werte aus einer chemischen Analyse der Weine. Unser Klassifikator soll nun anhand der Features vorhersagen, welche Rebsorte der jeweilige Wein hat.
Einlesen der Daten
Die Klassifikatoren in Smile haben alle ein ähnliches Eingabeformat. Wir benötigen die Klassen als Integer und die Features als double-Array. Da wir (normalerweise) vorher nicht wissen, wie viele Beispiele wir haben, bietet sich hier eine Liste an, die später in ein Array umgewandelt wird:
List<Integer> classes = new ArrayList<>();
List<double[]> features = new ArrayList<>();
Die features unserer Weine lassen sich dann mit etwas Java 8 Syntax recht einfach auslesen:
try (Stream<String> stream = Files.lines(Paths.get("/some/path/to/wine.data"))) {
stream.forEach((s) -> {
String[] vals = s.split(",");
classes.add(Integer.parseInt(vals[0])-1);
double[] instanceFeatures = new double[vals.length-1];
for(int i=1; i<vals.length;i++) {
instanceFeatures[i-1] = Double.parseDouble(vals[i]);
}
features.add(instanceFeatures);
});
}
Nun noch schnell in das gewünschte Input-Format unseres Classifiers umwandeln:
double[][] featureArray = features.stream().toArray(double[][]::new);
int[] classArray = classes.stream().mapToInt(i->i).toArray();
Klassifikation
Die eigentliche Klassifikation geht dann wenn man die Features erst mal hat dank des frameworks recht leicht von der Hand. Da wir direkt eine Cross-Validation durchführen wollen, benötigen wir einen ClassifierTrainer
, der die SVMs für die Cross-Validation trainiert:
ClassifierTrainer<double[]> trainer = new SVM.Trainer<>(new GaussianKernel(3), 10, 3, Multiclass.ONE_VS_ALL);
Da unsere Daten drei Klassen enthalten, müssen wir eine MultiClass SVM benutzen. Ich habe mich hier für die Variante one vs all entschieden, da diese schneller klassifiziert (weniger Durchgänge als bei one vs one).
Damit wir bei unserer Cross-Validation auch Ergebnisse sehen, müssen wir ein array aus ClassificationMeasure
-Instanzen erstellen. Dieses wird dann an die Cross-Validation Methode als Parameter übergeben:
ClassificationMeasure[] measures = new ClassificationMeasure[]{new Accuracy(), new Precision(), new Recall(), new Fallout(), new FMeasure()};
double[] results = Validation.cv(10, trainer, featureArray, classArray, measures);
for(int i=0;i<results.length;i++) {
System.out.println(measures[i].getClass().getSimpleName() + ": " + results[i]);
}
Und schon haben wir klassifiziert! Die Ergebnisse sind allerdings recht ernüchternd:
Accuracy: 0.398876404494382
Precision: 0.398876404494382
Recall: 1.0
Fallout: 1.0
FMeasure: 0.570281124497992
Normalisierung
Das liegt daran, dass wir vergessen haben zu normalisieren. SVMs (und auch einige andere Klassifikatoren), erwarten Werte im Intervall [0;1]
. Wir normalisieren allerdings nicht hart in dieses Intervall, sondern wählen eine andere Variante der Normalisierung:
Wir ziehen von jedem Feature den Mittelwert ab und teilen dann durch die Standardabweichung. Das Ganze muss per Feature geschehen, über alle Instanzen hinweg. Hier der Code:
//berechnen der Mittelwerte
double[] means = new double[13];
for(double[] fInstance : features) {
for(int i=0;i<fInstance.length;i++) {
means[i] += fInstance[i];
}
}
for(int i=0;i<means.length;i++) {
means[i] = means[i] / features.size();
}
//berechnen der Varianzen
double[] variances = new double[13];
for(double[] fInstance : features) {
for(int i=0;i<fInstance.length;i++) {
variances[i] += Math.pow(fInstance[i]-means[i], 2);
}
}
for(int i=0;i<variances.length;i++) {
variances[i] = variances[i] / features.size();
}
//die eigentliche Normalisierung
double[][] featureArray = new double[features.size()][];
for(int i=0;i<features.size();i++) {
double[] toNormalize = features.get(i);
double[] normalized = new double[features.get(i).length];
for(int k=0;k<features.get(i).length;k++) {
normalized[k] = (toNormalize[k]-means[k])/Math.sqrt(variances[k]);
}
featureArray[i] = normalized;
}
Hier gibt es elegantere Möglichkeiten zur Berechnung, im Sinne der Verständlichkeit habe ich hier aber einfach die Schulformeln im Code umgesetzt.
Mit den normalisierten Werten sind die Ergebnisse nun um einiges erbaulicher:
Accuracy: 0.9887640449438202
Precision: 0.9859154929577465
Recall: 0.9859154929577465
Fallout: 0.009345794392523366
FMeasure: 0.9859154929577465
Kernel und Parameter
Damit eine SVM im realen Einsatz gute Ergebnisse liefert, müssen immer der Parameter C
und die jeweiligen Parameter im Kernel optimiert werden. Das geschieht meist durch ausprobieren in 10er-Potenzen (also z.B. 0.001->0.01->0.1->1->10->100).
Der meist genutzte Kernel ist hier sicher der Gauß-Kernel, den wir auch oben benutzt haben. Ein linearer Kernel bietet jedoch den Vorteil, dass man leichter (bzw. überhaupt) visualisieren kann was gelernt wurde.
Die optimierte SVM trainieren und ausliefern
Wenn alle Parameter optimiert und wir mit den Ergebnissen zufrieden sind, wird es Zeit die SVM auf allen Trainingsdaten zu trainieren und auszuliefern. Dafür benutzen wir nun direkt die Klasse SVM
:
SVM<double[]> svm = new SVM<>(new GaussianKernel(3), 10, 3, Multiclass.ONE_VS_ALL);
Als Parameter nehmen wir hier natürlich die vorher optimierten Werte. Trainiert wird die SVM dann folgendermaßen:
svm.learn(featureArray, classArray);
Dieses Objekt kann man dann serialisieren (z.B. mit xstream oder Gson), und später wieder einlesen.
Wenn man das SVM-Objekt dann erst mal hat, kann man mit ihm Vorhersagen für andere Datensätze treffen. Die Datensätze, für die die Klasse dann vorhergesagt werden soll, müssen vorher auf jeden Fall mit den gleichen Werten normalisiert werden, mit denen auch die Trainingsdaten normalisiert wurden:
double[] newExample = new double[]{12.77,2.39,2.28,19.5,86,1.39,.51,.48,.64,9.899999,.57,1.63,470};
for(int k=0;k<newExample.length;k++) {
newExample[k] = (newExample[k]-means[k])/Math.sqrt(variances[k]);
}
svm.predict(newExample); //Ergebnis: 2
Die Werte für die Mittelwerte und die Varianzen sollten also gemeinsam mit dem SVM-Objekt serialisiert werden.
Als ich eine der SQL-Queries für das e-learning Portal geschrieben habe, kam darin auch ein JOIN
-Statement vor. Ich habe mich also innerlich auf das 'manuelle' zusammenfassen der Ergebnisse im Code mittels einer map[string]string
vorbereitet, da ist
mir durch Zufall die postgresql json_agg
aggregate function begegnet (postgresql aggregate functions). Die query würde dann folgendermaßen aussehen:
SELECT units.*, json_agg(pages.page_id) AS pages_arr FROM units LEFT OUTER JOIN pages ON ... GROUP BY units.unit_id
Der pages_arr
-Wert wird dann als JSON-Array zurückgeliefert, das einfach mit einem beliebigen JSON-Parser in der jeweiligen Programmiersprache in ein Array oder eine Liste verwandelt werden kann.
Sehr praktisch wie ich finde!
Neue Serie: Web-App mit go/REST/postgres/JWT/emberjs
Mi 10 August 2016 by Oliver PaetzelTags: golang REST postgres sql JWT emberjs
Für ein Uni-Projekt habe ich die Aufgabe, eine E-Learning Applikation zu entwerfen. Diese soll zunächst als Web-App realisiert werden, die Möglichkeit native Apps für mobile Endgeräte nachzuentwickeln, soll allerdings auch vorhanden sein.
Für diese Anforderungen eignet sich ganz besonders die Kombination (REST)-api im backend und JavaScript MVC Framework für das frontend. Da ich normalerweise Web-Anwendungen mit Java/MySQL/JSF entwickle, wäre zumindest schon mal das frontend neu für mich gewesen. Um das Ganze noch interessanter zu machen, habe ich zusätzlich noch die Programmiersprache im Backend und die Datenbank geändert.
Als Programmiersprache für das backend werde ich go von Google benutzen und als Datenbank postgresql. Interessant an go ist vor allem die Art des deployens: Es wird eine komplett eigenständige binary mit Runtime und Garbage Collector erstellt. Beim jetzigen Projektstand ist diese lediglich 7,3MB groß (kompiliert mit go 1.7rc6). Es gibt noch viele weitere Dinge, die go anders macht als andere Programmiersprachen, doch für mich ist dies der interessanteste Unterschied zu Java, wo die "fat-WARs" schon mal 30-40 MB groß sein können mit allen Abhängigkeiten.
Für das frontend habe ich mich zunächst für emberjs entschieden. Da ich noch nicht mit der frontend-Entwicklung begonnen habe, kann ich hierzu noch nicht so viel schreiben, außer dass mir die Tutorials besser gefallen haben als die von Angular. Bei emberjs scheint mir alles etwas klarer und die Struktur der Applikation mehr vorgegeben zu sein.
Die Blog-Serie soll keine minutiöse Beschreibung der Entwicklung werden, sondern vielmehr interessante Entdeckungen und Entscheidungen festhalten.
JavaFX SceneBuilder unter Debian Jessie
Anleitung zur Installation von JavaFX SceneBuilder in Debian Jessie
weiterlesenDolphin Emulator in Debian Jessie kompilieren
Anleitung zum Bauen des Dolphin GameCube und Wii Emulator in Debian Jessie.
weiterlesenLilyPond "Wiederholung unter der Klammer"
Ich wollte eine Trompetenstimme mit LilyPond transponieren. In dem Stück gab es aber unter der "Klammer 2" ein weiteres Wiederholungszeichen. Ich beschreibe wie ich dies in LilyPond umgesetzt habe.
weiterlesenJava substring() Performance
Tags: java substring performance string java.lang charsequence
Ein Blick auf die substring() Implementierung in der openJDK Standardbibliothek und warum sie für mich ungeeignet war.
weiterlesenlftp und SSL
Konfiguration von lftp zum syncen der pelican-Seite auf den Server mit erzwungenem SSL
weiterlesenErster Post
Mein erster Blogeintrag. Hier geht es um die Einrichtung dieses Blogs mit pelican.
weiterlesen