about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authormakefunstuff <[email protected]>2024-07-11 00:11:07 +0200
committermakefunstuff <[email protected]>2024-07-11 00:11:07 +0200
commit9b9e63476616a123eb4d00f87f677ef2d9478897 (patch)
treee079d7d390f1ba0f86b97d55e5ab9e4dd2076d86 /src
parentfd3a727f4b4f819178225cdd87d8788b85c4b86c (diff)
downloadtinkerbunk-9b9e63476616a123eb4d00f87f677ef2d9478897.tar.gz
fix tests for linear regression
Diffstat (limited to '')
-rw-r--r--src/monkey_learns/linear_regression.zig3
-rw-r--r--src/monkey_learns/test.zig8
2 files changed, 8 insertions, 3 deletions
diff --git a/src/monkey_learns/linear_regression.zig b/src/monkey_learns/linear_regression.zig
index 4cc5a9f..17b886d 100644
--- a/src/monkey_learns/linear_regression.zig
+++ b/src/monkey_learns/linear_regression.zig
@@ -61,7 +61,7 @@ test "Linear Regression" {
     const y = [_]f64{ 3.1, 4.9, 7.2, 9.1, 11.0, 12.8, 14.9, 17.2, 18.8, 21.1 };
 
     const learning_rate: f64 = 0.01;
-    const epochs: usize = 1000;
+    const epochs: usize = 10000;
     model.train(&x, &y, learning_rate, epochs);
 
     try testing.expect(@abs(model.weight - 2.0) < 0.1);
@@ -72,6 +72,7 @@ test "Linear Regression" {
     for (test_x, expected_y) |xi, yi| {
         const prediction = model.predict(xi);
         try testing.expect(@abs(prediction - yi) < 0.5);
+        std.debug.print("Expected: {} | Predicted {}\n", .{ yi, prediction });
     }
 
     const mse = model.loss_function(&x, &y);
diff --git a/src/monkey_learns/test.zig b/src/monkey_learns/test.zig
index 7c3472e..47bfdfa 100644
--- a/src/monkey_learns/test.zig
+++ b/src/monkey_learns/test.zig
@@ -1,3 +1,7 @@
-comptime {
-    _ = @import("linear_regression.zig");
+const std = @import("std");
+const testing = std.testing;
+pub const linear_regression = @import("linear_regression.zig");
+
+test {
+    testing.refAllDeclsRecursive(@This());
 }