diff options
author | Iurii Plugatariov <[email protected]> | 2024-07-08 23:56:57 +0200 |
---|---|---|
committer | GitHub <[email protected]> | 2024-07-08 23:56:57 +0200 |
commit | 3335e781107f88646d9403357411303d4f8f0a4f (patch) | |
tree | d4e77fe62c5b0bc87632485f1968fb6c87bc013d /src | |
parent | c7b32cd10370f512bad154d0282f78004ffa4c42 (diff) | |
parent | aafe02ddf1931659b9d42e403d8fcb37450afb43 (diff) | |
download | tinkerbunk-3335e781107f88646d9403357411303d4f8f0a4f.tar.gz |
Merge pull request #3 from makefunstuff/perceptron
Learning simple neural nets just for fun
Diffstat (limited to 'src')
-rw-r--r-- | src/monkey_brain/main.zig | 5 | ||||
-rw-r--r-- | src/monkey_brain/multi_layer_perceptron.zig | 0 | ||||
-rw-r--r-- | src/monkey_brain/perceptron.zig | 83 | ||||
-rw-r--r-- | src/monkey_brain/test.zig | 5 |
4 files changed, 93 insertions, 0 deletions
diff --git a/src/monkey_brain/main.zig b/src/monkey_brain/main.zig new file mode 100644 index 0000000..611dcd1 --- /dev/null +++ b/src/monkey_brain/main.zig @@ -0,0 +1,5 @@ +const perceptron = @import("perceptron.zig"); + +pub fn main() !void { + try perceptron.demo(); +} diff --git a/src/monkey_brain/multi_layer_perceptron.zig b/src/monkey_brain/multi_layer_perceptron.zig new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/monkey_brain/multi_layer_perceptron.zig diff --git a/src/monkey_brain/perceptron.zig b/src/monkey_brain/perceptron.zig new file mode 100644 index 0000000..893bb16 --- /dev/null +++ b/src/monkey_brain/perceptron.zig @@ -0,0 +1,83 @@ +const std = @import("std"); +const testing = std.testing; +const math = std.math; + +const input_size: usize = 2; +const training_set_size: usize = 4; +const learning_rate: f64 = 0.1; +const epochs: u64 = 1000000; + +// https://en.wikipedia.org/wiki/Sigmoid_function - more details +// https://www.youtube.com/watch?v=TPqr8t919YM +fn sigmoid(x: f64) f64 { + return 1.0 / (1.0 + math.exp(-x)); +} + +fn sigmoid_derivative(output: f64) f64 { + return output * (1.0 - output); +} + +fn predict(weights: [input_size]f64, bias: f64, inputs: [input_size]f64) f64 { + var total: f64 = bias; + for (inputs, 0..) |input, i| { + total += weights[i] * input; + } + return sigmoid(total); +} + +fn train(weights: *[input_size]f64, bias: *f64, training_data: [training_set_size][input_size]f64, labels: [training_set_size]f64) void { + for (0..epochs) |_| { + for (training_data, labels) |inputs, label| { + const prediction = predict(weights.*, bias.*, inputs); + const err = label - prediction; + const adjustment = err * sigmoid_derivative(prediction); + for (inputs, 0..) |input, j| { + weights[j] += learning_rate * adjustment * input; + } + bias.* += learning_rate * adjustment; + } + } +} + +pub fn demo() !void { + var weights = [_]f64{ std.crypto.random.float(f64), std.crypto.random.float(f64) }; + var bias: f64 = std.crypto.random.float(f64); + + const training_data = [_][input_size]f64{ + .{ 0, 0 }, + .{ 0, 1 }, + .{ 1, 0 }, + .{ 1, 1 }, + }; + const labels = [_]f64{ 0, 1, 1, 1 }; // OR operation + + train(&weights, &bias, training_data, labels); + + std.debug.print("Trained weights: {d}, {d}\n", .{ weights[0], weights[1] }); + std.debug.print("Trained bias: {d}\n", .{bias}); + + for (training_data, labels) |inputs, expected| { + const prediction = predict(weights, bias, inputs); + std.debug.print("Input: {d}, {d}, Predicted: {d:.4}, Expected: {d}\n", .{ inputs[0], inputs[1], prediction, expected }); + } +} + +test "OR gate" { + var weights = [_]f64{ 0, 0 }; + var bias: f64 = 0; + + const training_data = [_][input_size]f64{ + .{ 0, 0 }, + .{ 0, 1 }, + .{ 1, 0 }, + .{ 1, 1 }, + }; + const labels = [_]f64{ 0, 1, 1, 1 }; + + train(&weights, &bias, training_data, labels); + + for (training_data, labels) |inputs, expected| { + const prediction = predict(weights, bias, inputs); + try testing.expect((prediction - expected) < 0.1); + } +} diff --git a/src/monkey_brain/test.zig b/src/monkey_brain/test.zig new file mode 100644 index 0000000..4d4a04b --- /dev/null +++ b/src/monkey_brain/test.zig @@ -0,0 +1,5 @@ +pub const perceptron = @import("perceptron.zig"); + +test { + @import("std").testing.refAllDecls(@This()); +} |