about summary refs log tree commit diff
path: root/src/monkey_learns
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/monkey_learns/linear_regression.zig3
-rw-r--r--src/monkey_learns/test.zig8
2 files changed, 8 insertions, 3 deletions
diff --git a/src/monkey_learns/linear_regression.zig b/src/monkey_learns/linear_regression.zig
index 4cc5a9f..17b886d 100644
--- a/src/monkey_learns/linear_regression.zig
+++ b/src/monkey_learns/linear_regression.zig
@@ -61,7 +61,7 @@ test "Linear Regression" {
     const y = [_]f64{ 3.1, 4.9, 7.2, 9.1, 11.0, 12.8, 14.9, 17.2, 18.8, 21.1 };
 
     const learning_rate: f64 = 0.01;
-    const epochs: usize = 1000;
+    const epochs: usize = 10000;
     model.train(&x, &y, learning_rate, epochs);
 
     try testing.expect(@abs(model.weight - 2.0) < 0.1);
@@ -72,6 +72,7 @@ test "Linear Regression" {
     for (test_x, expected_y) |xi, yi| {
         const prediction = model.predict(xi);
         try testing.expect(@abs(prediction - yi) < 0.5);
+        std.debug.print("Expected: {} | Predicted {}\n", .{ yi, prediction });
     }
 
     const mse = model.loss_function(&x, &y);
diff --git a/src/monkey_learns/test.zig b/src/monkey_learns/test.zig
index 7c3472e..47bfdfa 100644
--- a/src/monkey_learns/test.zig
+++ b/src/monkey_learns/test.zig
@@ -1,3 +1,7 @@
-comptime {
-    _ = @import("linear_regression.zig");
+const std = @import("std");
+const testing = std.testing;
+pub const linear_regression = @import("linear_regression.zig");
+
+test {
+    testing.refAllDeclsRecursive(@This());
 }