about summary refs log tree commit diff
path: root/src/monkey_brain/perceptron.zig
diff options
context:
space:
mode:
authormakefunstuff <[email protected]>2024-07-11 00:05:11 +0200
committermakefunstuff <[email protected]>2024-07-11 00:05:11 +0200
commitfd3a727f4b4f819178225cdd87d8788b85c4b86c (patch)
tree58c6ee139f7768ec739477354af0faf001b2d40d /src/monkey_brain/perceptron.zig
parente5a064e82fe7d36f4423abdf798f9f1f0a90ae0e (diff)
downloadtinkerbunk-fd3a727f4b4f819178225cdd87d8788b85c4b86c.tar.gz
some fixes and dummy neuron
Diffstat (limited to '')
-rw-r--r--src/monkey_brain/perceptron.zig16
1 files changed, 10 insertions, 6 deletions
diff --git a/src/monkey_brain/perceptron.zig b/src/monkey_brain/perceptron.zig
index 893bb16..1d819d1 100644
--- a/src/monkey_brain/perceptron.zig
+++ b/src/monkey_brain/perceptron.zig
@@ -5,7 +5,7 @@ const math = std.math;
 const input_size: usize = 2;
 const training_set_size: usize = 4;
 const learning_rate: f64 = 0.1;
-const epochs: u64 = 1000000;
+const epochs: u64 = 10000;
 
 // https://en.wikipedia.org/wiki/Sigmoid_function - more details
 // https://www.youtube.com/watch?v=TPqr8t919YM
@@ -40,8 +40,8 @@ fn train(weights: *[input_size]f64, bias: *f64, training_data: [training_set_siz
 }
 
 pub fn demo() !void {
-    var weights = [_]f64{ std.crypto.random.float(f64), std.crypto.random.float(f64) };
-    var bias: f64 = std.crypto.random.float(f64);
+    var weights = [_]f64{ std.crypto.random.float(f64) * 2 - 1, std.crypto.random.float(f64) * 2 - 1 };
+    var bias: f64 = std.crypto.random.float(f64) * 2 - 1;
 
     const training_data = [_][input_size]f64{
         .{ 0, 0 },
@@ -63,8 +63,8 @@ pub fn demo() !void {
 }
 
 test "OR gate" {
-    var weights = [_]f64{ 0, 0 };
-    var bias: f64 = 0;
+    var weights = [_]f64{ 0.3, 0.2 };
+    var bias: f64 = 0.5;
 
     const training_data = [_][input_size]f64{
         .{ 0, 0 },
@@ -78,6 +78,10 @@ test "OR gate" {
 
     for (training_data, labels) |inputs, expected| {
         const prediction = predict(weights, bias, inputs);
-        try testing.expect((prediction - expected) < 0.1);
+        const predicted_error = prediction - expected;
+        std.debug.print("Predicted error {}\n", .{predicted_error});
+        std.debug.print("Predicted: {} | Expected: {}\n", .{ prediction, expected });
+
+        try testing.expect(predicted_error < 0.1);
     }
 }