about summary refs log tree commit diff
path: root/src/monkey_learns/linear_regression.zig
blob: 4cc5a9f043c33483a9f0df7c6f995cb0ef9ac920 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
const std = @import("std");
const testing = std.testing;
const math = std.math;

/// Single-feature linear model `y = weight * x + bias`, fitted by per-sample
/// (stochastic) gradient descent on the squared error.
const LinearRegression = struct {
    const Self = @This();

    /// `train` prints a progress line every `log_interval` epochs and on the last epoch.
    const log_interval: usize = 1000;

    /// Slope of the fitted line.
    weight: f64,
    /// Intercept of the fitted line.
    bias: f64,

    /// Return an untrained model with weight and bias both set to zero.
    fn init() Self {
        return .{ .weight = 0.0, .bias = 0.0 };
    }

    /// Evaluate the model at `x`: `weight * x + bias`.
    fn predict(self: Self, x: f64) f64 {
        return self.weight * x + self.bias;
    }

    /// Fit `weight` and `bias` to the samples `(x[i], y[i])` in place.
    ///
    /// Each epoch visits every sample once and applies one gradient step per
    /// sample, scaled by `learning_rate`. `x` and `y` must have equal lengths;
    /// this is asserted. Empty input leaves the model untouched and prints nothing.
    /// Progress is written with `std.debug.print`. The reported MSE is
    /// `loss_function` evaluated on the parameters at the end of the logged epoch.
    fn train(self: *Self, x: []const f64, y: []const f64, learning_rate: f64, epochs: usize) void {
        std.debug.assert(x.len == y.len);

        // Nothing to fit; without this every logged epoch would report 0/0 = NaN.
        if (x.len == 0) return;

        for (0..epochs) |epoch| {
            for (x, y) |xi, yi| {
                const err = self.predict(xi) - yi;
                // Gradient of 0.5 * err^2: d/dweight = err * xi, d/dbias = err.
                self.weight -= learning_rate * err * xi;
                self.bias -= learning_rate * err;
            }

            // `epochs >= 1` inside the loop, so `epochs - 1` cannot underflow.
            if (epoch % log_interval == 0 or epoch == epochs - 1) {
                // Use the shared MSE definition so the log matches `loss_function`
                // for the parameters as they stand after this epoch.
                std.debug.print("Epoch {d}: MSE = {d:.6}\n", .{ epoch, self.loss_function(x, y) });
            }
        }
    }

    // https://en.wikipedia.org/wiki/Mean_squared_error
    /// Mean squared error of the model over the samples `(x[i], y[i])`:
    /// `sum((predict(x[i]) - y[i])^2) / x.len`.
    /// `x` and `y` must have equal lengths. For empty input the result is the
    /// floating-point quotient 0/0 (NaN).
    fn loss_function(self: Self, x: []const f64, y: []const f64) f64 {
        std.debug.assert(x.len == y.len);

        var squared_sum: f64 = 0.0;
        for (x, y) |xi, yi| {
            const err = self.predict(xi) - yi;
            squared_sum += err * err;
        }
        const n: f64 = @floatFromInt(x.len);
        return squared_sum / n;
    }
};

// End-to-end check: train on ten samples close to y = 2x + 1, then check the
// fitted parameters, point predictions, and MSE on training and held-out data.
test "Linear Regression" {
    var model = LinearRegression.init();

    const x = [_]f64{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
    const y = [_]f64{ 3.1, 4.9, 7.2, 9.1, 11.0, 12.8, 14.9, 17.2, 18.8, 21.1 };

    const learning_rate: f64 = 0.01;
    const epochs: usize = 1000;
    model.train(&x, &y, learning_rate, epochs);

    // Held-out points just beyond the training range.
    const new_x = [_]f64{ 11, 12, 13 };
    const new_y = [_]f64{ 23.1, 24.9, 27.2 };

    const mse = model.loss_function(&x, &y);
    const new_mse = model.loss_function(&new_x, &new_y);

    // Print diagnostics before asserting so they are visible when an assertion fails.
    std.debug.print("\nTrained model: y = {d:.4}x + {d:.4}\n", .{ model.weight, model.bias });
    std.debug.print("Mean Squared Error on training data: {d:.4}\n", .{mse});
    std.debug.print("Mean Squared Error on new data: {d:.4}\n", .{new_mse});

    // expectApproxEqAbs reports expected/actual values on failure.
    try testing.expectApproxEqAbs(@as(f64, 2.0), model.weight, 0.1);
    try testing.expectApproxEqAbs(@as(f64, 1.0), model.bias, 0.1);

    const test_x = [_]f64{ 0, 5, 10 };
    const expected_y = [_]f64{ 1, 11, 21 };
    for (test_x, expected_y) |xi, yi| {
        try testing.expectApproxEqAbs(yi, model.predict(xi), 0.5);
    }

    try testing.expect(mse < 0.1);
    try testing.expect(new_mse < 0.2);
}