File tree

6 files changed

+56
-53
lines changed

6 files changed

+56
-53
lines changed
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,15 @@
1717
package chapter7;
1818

1919
import java.util.ArrayList;
20+
import java.util.Arrays;
2021
import java.util.Collections;
2122
import java.util.List;
2223

2324
public class IrisTest {
25+
public static final String IRIS_SETOSA = "Iris-setosa";
26+
public static final String IRIS_VERSICOLOR = "Iris-versicolor";
27+
public static final String IRIS_VIRGINICA = "Iris-virginica";
28+
2429
private List<double[]> irisParameters = new ArrayList<>();
2530
private List<double[]> irisClassifications = new ArrayList<>();
2631
private List<String> irisSpecies = new ArrayList<>();
@@ -32,19 +37,20 @@ public IrisTest() {
3237
Collections.shuffle(irisDataset);
3338
for (String[] iris : irisDataset) {
3439
// first four items are parameters (doubles)
35-
double[] parameters = new double[4];
36-
for (int i = 0; i < parameters.length; i++) {
37-
parameters[i] = Double.parseDouble(iris[i]);
38-
}
40+
double[] parameters = Arrays.stream(iris)
41+
.limit(4)
42+
.mapToDouble(Double::parseDouble)
43+
.toArray();
3944
irisParameters.add(parameters);
4045
// last item is species
4146
String species = iris[4];
42-
if (species.equals("Iris-setosa")) {
43-
irisClassifications.add(new double[] { 1.0, 0.0, 0.0 });
44-
} else if (species.equals("Iris-versicolor")) {
45-
irisClassifications.add(new double[] { 0.0, 1.0, 0.0 });
46-
} else { // Iris-virginica
47-
irisClassifications.add(new double[] { 0.0, 0.0, 1.0 });
47+
switch (species) {
48+
case IRIS_SETOSA :
49+
irisClassifications.add(new double[] { 1.0, 0.0, 0.0 }); break;
50+
case IRIS_VERSICOLOR :
51+
irisClassifications.add(new double[] { 0.0, 1.0, 0.0 }); break;
52+
default :
53+
irisClassifications.add(new double[] { 0.0, 0.0, 1.0 }); break;
4854
}
4955
irisSpecies.add(species);
5056
}
@@ -54,12 +60,12 @@ public IrisTest() {
5460
public String irisInterpretOutput(double[] output) {
5561
double max = Util.max(output);
5662
if (max == output[0]) {
57-
return "Iris-setosa";
58-
} else if (max == output[1]) {
59-
return "Iris-versicolor";
60-
} else {
61-
return "Iris-virginica";
63+
return IRIS_SETOSA;
64+
}
65+
if (max == output[1]) {
66+
return IRIS_VERSICOLOR;
6267
}
68+
return IRIS_VIRGINICA;
6369
}
6470

6571
public Network<String>.Results classify() {
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,16 @@
2424

2525
public class Layer {
2626
public Optional<Layer> previousLayer;
27-
public List<Neuron> neurons;
27+
public List<Neuron> neurons = new ArrayList<>();
2828
public double[] outputCache;
2929

3030
public Layer(Optional<Layer> previousLayer, int numNeurons, double learningRate,
3131
DoubleUnaryOperator activationFunction, DoubleUnaryOperator derivativeActivationFunction) {
3232
this.previousLayer = previousLayer;
33-
neurons = new ArrayList<>();
33+
Random random = new Random();
3434
for (int i = 0; i < numNeurons; i++) {
3535
double[] randomWeights = null;
3636
if (previousLayer.isPresent()) {
37-
Random random = new Random();
3837
randomWeights = random.doubles(previousLayer.get().neurons.size()).toArray();
3938
}
4039
Neuron neuron = new Neuron(randomWeights, learningRate, activationFunction, derivativeActivationFunction);
@@ -63,7 +62,7 @@ public void calculateDeltasForOutputLayer(double[] expected) {
6362
// should not be called on output layer
6463
public void calculateDeltasForHiddenLayer(Layer nextLayer) {
6564
for (int i = 0; i < neurons.size(); i++) {
66-
final int index = i;
65+
int index = i;
6766
double[] nextWeights = nextLayer.neurons.stream().mapToDouble(n -> n.weights[index]).toArray();
6867
double[] nextDeltas = nextLayer.neurons.stream().mapToDouble(n -> n.delta).toArray();
6968
double sumWeightsAndDeltas = Util.dotProduct(nextWeights, nextDeltas);
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,13 @@
2323
import java.util.function.Function;
2424

2525
public class Network<T> {
26-
private List<Layer> layers;
26+
private List<Layer> layers = new ArrayList<>();
2727

2828
public Network(int[] layerStructure, double learningRate,
2929
DoubleUnaryOperator activationFunction, DoubleUnaryOperator derivativeActivationFunction) {
3030
if (layerStructure.length < 3) {
3131
throw new IllegalArgumentException("Error: Should be at least 3 layers (1 input, 1 hidden, 1 output).");
3232
}
33-
layers = new ArrayList<>();
3433
// input layer
3534
Layer inputLayer = new Layer(Optional.empty(), layerStructure[0], learningRate, activationFunction,
3635
derivativeActivationFunction);
@@ -47,11 +46,7 @@ public Network(int[] layerStructure, double learningRate,
4746
// Pushes input data to the first layer, then output from the first
4847
// as input to the second, second to the third, etc.
4948
private double[] outputs(double[] input) {
50-
double[] result = input;
51-
for (Layer layer : layers) {
52-
result = layer.outputs(result);
53-
}
54-
return result;
49+
return layers.stream().reduce(input, (r, l) -> l.outputs(r), (r1, r2) -> r1);
5550
}
5651

5752
// Figure out each neuron's changes based on the errors of the output
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
package chapter7;
1818

1919
import java.io.BufferedReader;
20+
import java.io.IOException;
2021
import java.io.InputStream;
2122
import java.io.InputStreamReader;
2223
import java.util.ArrayList;
24+
import java.util.Arrays;
2325
import java.util.Collections;
2426
import java.util.List;
2527
import java.util.stream.Collectors;
@@ -63,22 +65,23 @@ public static void normalizeByFeatureScaling(List<double[]> dataset) {
6365

6466
// Load a CSV file into a List of String arrays
6567
public static List<String[]> loadCSV(String filename) {
66-
InputStream inputStream = Util.class.getResourceAsStream(filename);
67-
InputStreamReader inputStreamReader = new InputStreamReader(inputStream);
68-
BufferedReader bufferedReader = new BufferedReader(inputStreamReader);
69-
return bufferedReader.lines().map(line -> line.split(","))
70-
.collect(Collectors.toList());
68+
try (InputStream inputStream = Util.class.getResourceAsStream(filename)) {
69+
InputStreamReader inputStreamReader = new InputStreamReader(inputStream);
70+
BufferedReader bufferedReader = new BufferedReader(inputStreamReader);
71+
return bufferedReader.lines().map(line -> line.split(","))
72+
.collect(Collectors.toList());
73+
}
74+
catch (IOException e) {
75+
e.printStackTrace();
76+
throw new RuntimeException(e.getMessage(), e);
77+
}
7178
}
7279

7380
// Find the maximum in an array of doubles
7481
public static double max(double[] numbers) {
75-
double m = Double.MIN_VALUE;
76-
for (double number : numbers) {
77-
if (number > m) {
78-
m = number;
79-
}
80-
}
81-
return m;
82+
return Arrays.stream(numbers)
83+
.max()
84+
.orElse(Double.MIN_VALUE);
8285
}
8386

8487
}
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package chapter7;
1818

1919
import java.util.ArrayList;
20+
import java.util.Arrays;
2021
import java.util.Collections;
2122
import java.util.List;
2223

@@ -32,19 +33,20 @@ public WineTest() {
3233
Collections.shuffle(wineDataset);
3334
for (String[] wine : wineDataset) {
3435
// last thirteen items are parameters (doubles)
35-
double[] parameters = new double[13];
36-
for (int i = 1; i < (parameters.length + 1); i++) {
37-
parameters[i - 1] = Double.parseDouble(wine[i]);
38-
}
36+
double[] parameters = Arrays.stream(wine)
37+
.skip(1)
38+
.mapToDouble(Double::parseDouble)
39+
.toArray();
3940
wineParameters.add(parameters);
4041
// first item is species
4142
int species = Integer.parseInt(wine[0]);
42-
if (species == 1) {
43-
wineClassifications.add(new double[] { 1.0, 0.0, 0.0 });
44-
} else if (species == 2) {
45-
wineClassifications.add(new double[] { 0.0, 1.0, 0.0 });
46-
} else { // 3
47-
wineClassifications.add(new double[] { 0.0, 0.0, 1.0 });
43+
switch (species) {
44+
case 1 :
45+
wineClassifications.add(new double[] { 1.0, 0.0, 0.0 }); break;
46+
case 2 :
47+
wineClassifications.add(new double[] { 0.0, 1.0, 0.0 }); break;
48+
default :
49+
wineClassifications.add(new double[] { 0.0, 0.0, 1.0 });; break;
4850
}
4951
wineSpecies.add(species);
5052
}
@@ -55,11 +57,11 @@ public Integer wineInterpretOutput(double[] output) {
5557
double max = Util.max(output);
5658
if (max == output[0]) {
5759
return 1;
58-
} else if (max == output[1]) {
60+
}
61+
if (max == output[1]) {
5962
return 2;
60-
} else {
61-
return 3;
6263
}
64+
return 3;
6365
}
6466

6567
public Network<Integer>.Results classify() {
This file was deleted.

0 commit comments

Comments
 (0)