about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/monkey_brain/main.zig69
1 files changed, 44 insertions, 25 deletions
diff --git a/src/monkey_brain/main.zig b/src/monkey_brain/main.zig
index 52ff9e5..5d0e616 100644
--- a/src/monkey_brain/main.zig
+++ b/src/monkey_brain/main.zig
@@ -1,13 +1,14 @@
 const std = @import("std");
 const testing = std.testing;
+const math = std.math;
 
 const input_size: usize = 2;
 const training_set_size: usize = 4;
 const learning_rate: f64 = 0.1;
-const epochs: u64 = 100 * 1000;
+const epochs: u64 = 1000000;
 
 fn sigmoid(x: f64) f64 {
-    return 1.0 / (1.0 + std.math.exp(-x));
+    return 1.0 / (1.0 + math.exp(-x));
 }
 
 fn sigmoid_derivative(output: f64) f64 {
@@ -15,23 +16,21 @@ fn sigmoid_derivative(output: f64) f64 {
 }
 
 fn predict(weights: [input_size]f64, bias: f64, inputs: [input_size]f64) f64 {
-    var total: f64 = 0.0;
-    for (0..input_size) |i| {
-        total += weights[i] * inputs[i];
+    var total: f64 = bias;
+    for (inputs, 0..) |input, i| {
+        total += weights[i] * input;
     }
-    total += bias;
     return sigmoid(total);
 }
 
 fn train(weights: *[input_size]f64, bias: *f64, training_data: [training_set_size][input_size]f64, labels: [training_set_size]f64) void {
     for (0..epochs) |_| {
-        for (0..training_set_size) |i| {
-            const prediction = predict(weights.*, bias.*, training_data[i]);
-            const err = labels[i] - prediction;
+        for (training_data, labels) |inputs, label| {
+            const prediction = predict(weights.*, bias.*, inputs);
+            const err = label - prediction;
             const adjustment = err * sigmoid_derivative(prediction);
-
-            for (0..input_size) |j| {
-                weights[j] += learning_rate * adjustment * training_data[i][j];
+            for (inputs, 0..) |input, j| {
+                weights[j] += learning_rate * adjustment * input;
             }
             bias.* += learning_rate * adjustment;
         }
@@ -39,24 +38,44 @@ fn train(weights: *[input_size]f64, bias: *f64, training_data: [training_set_siz
 }
 
 pub fn main() !void {
-    const w1 = std.crypto.random.float(f64);
-    const w2 = std.crypto.random.float(f64);
-
-    var weights: [input_size]f64 = .{ w1, w2 };
-    var bias: f64 = 0.0;
-
-    const training_data: [training_set_size][input_size]f64 = .{ .{ 0, 0 }, .{ 0, 1 }, .{ 1, 0 }, .{ 1, 1 } };
+    var weights = [_]f64{ std.crypto.random.float(f64), std.crypto.random.float(f64) };
+    var bias: f64 = std.crypto.random.float(f64);
 
-    const labels: [training_set_size]f64 = .{ 0, 0, 0, 1 };
+    const training_data = [_][input_size]f64{
+        .{ 0, 0 },
+        .{ 0, 1 },
+        .{ 1, 0 },
+        .{ 1, 1 },
+    };
+    const labels = [_]f64{ 0, 1, 1, 1 }; // OR operation
 
     train(&weights, &bias, training_data, labels);
 
-    for (0..training_set_size) |i| {
-        const prediction = predict(weights, bias, training_data[i]);
-        std.log.info("Input {} {}, Predicted output: {}", .{ training_data[i][0], training_data[i][1], prediction });
+    std.debug.print("Trained weights: {d}, {d}\n", .{ weights[0], weights[1] });
+    std.debug.print("Trained bias: {d}\n", .{bias});
+
+    for (training_data, labels) |inputs, expected| {
+        const prediction = predict(weights, bias, inputs);
+        std.debug.print("Input: {d}, {d}, Predicted: {d:.4}, Expected: {d}\n", .{ inputs[0], inputs[1], prediction, expected });
     }
 }
 
-test "hello" {
-    try testing.expect(true);
+test "OR gate" {
+    var weights = [_]f64{ 0, 0 };
+    var bias: f64 = 0;
+
+    const training_data = [_][input_size]f64{
+        .{ 0, 0 },
+        .{ 0, 1 },
+        .{ 1, 0 },
+        .{ 1, 1 },
+    };
+    const labels = [_]f64{ 0, 1, 1, 1 };
+
+    train(&weights, &bias, training_data, labels);
+
+    for (training_data, labels) |inputs, expected| {
+        const prediction = predict(weights, bias, inputs);
+        try testing.expect((prediction - expected) < 0.1);
+    }
 }