diff options
author | makefunstuff <[email protected]> | 2024-07-11 00:11:07 +0200 |
---|---|---|
committer | makefunstuff <[email protected]> | 2024-07-11 00:11:07 +0200 |
commit | 9b9e63476616a123eb4d00f87f677ef2d9478897 (patch) | |
tree | e079d7d390f1ba0f86b97d55e5ab9e4dd2076d86 | |
parent | fd3a727f4b4f819178225cdd87d8788b85c4b86c (diff) | |
download | tinkerbunk-9b9e63476616a123eb4d00f87f677ef2d9478897.tar.gz |
fix tests for linear regression
-rw-r--r-- | src/monkey_learns/linear_regression.zig | 3 | ||||
-rw-r--r-- | src/monkey_learns/test.zig | 8 |
2 files changed, 8 insertions, 3 deletions
diff --git a/src/monkey_learns/linear_regression.zig b/src/monkey_learns/linear_regression.zig index 4cc5a9f..17b886d 100644 --- a/src/monkey_learns/linear_regression.zig +++ b/src/monkey_learns/linear_regression.zig @@ -61,7 +61,7 @@ test "Linear Regression" { const y = [_]f64{ 3.1, 4.9, 7.2, 9.1, 11.0, 12.8, 14.9, 17.2, 18.8, 21.1 }; const learning_rate: f64 = 0.01; - const epochs: usize = 1000; + const epochs: usize = 10000; model.train(&x, &y, learning_rate, epochs); try testing.expect(@abs(model.weight - 2.0) < 0.1); @@ -72,6 +72,7 @@ test "Linear Regression" { for (test_x, expected_y) |xi, yi| { const prediction = model.predict(xi); try testing.expect(@abs(prediction - yi) < 0.5); + std.debug.print("Expected: {} | Predicted {}\n", .{ yi, prediction }); } const mse = model.loss_function(&x, &y); diff --git a/src/monkey_learns/test.zig b/src/monkey_learns/test.zig index 7c3472e..47bfdfa 100644 --- a/src/monkey_learns/test.zig +++ b/src/monkey_learns/test.zig @@ -1,3 +1,7 @@ -comptime { - _ = @import("linear_regression.zig"); +const std = @import("std"); +const testing = std.testing; +pub const linear_regression = @import("linear_regression.zig"); + +test { + testing.refAllDeclsRecursive(@This()); } |