From 9b9e63476616a123eb4d00f87f677ef2d9478897 Mon Sep 17 00:00:00 2001 From: makefunstuff Date: Thu, 11 Jul 2024 01:11:07 +0300 Subject: fix tests for linear regression --- src/monkey_learns/linear_regression.zig | 3 ++- src/monkey_learns/test.zig | 8 ++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) (limited to 'src/monkey_learns') diff --git a/src/monkey_learns/linear_regression.zig b/src/monkey_learns/linear_regression.zig index 4cc5a9f..17b886d 100644 --- a/src/monkey_learns/linear_regression.zig +++ b/src/monkey_learns/linear_regression.zig @@ -61,7 +61,7 @@ test "Linear Regression" { const y = [_]f64{ 3.1, 4.9, 7.2, 9.1, 11.0, 12.8, 14.9, 17.2, 18.8, 21.1 }; const learning_rate: f64 = 0.01; - const epochs: usize = 1000; + const epochs: usize = 10000; model.train(&x, &y, learning_rate, epochs); try testing.expect(@abs(model.weight - 2.0) < 0.1); @@ -72,6 +72,7 @@ test "Linear Regression" { for (test_x, expected_y) |xi, yi| { const prediction = model.predict(xi); try testing.expect(@abs(prediction - yi) < 0.5); + std.debug.print("Expected: {} | Predicted {}\n", .{ yi, prediction }); } const mse = model.loss_function(&x, &y); diff --git a/src/monkey_learns/test.zig b/src/monkey_learns/test.zig index 7c3472e..47bfdfa 100644 --- a/src/monkey_learns/test.zig +++ b/src/monkey_learns/test.zig @@ -1,3 +1,7 @@ -comptime { - _ = @import("linear_regression.zig"); +const std = @import("std"); +const testing = std.testing; +pub const linear_regression = @import("linear_regression.zig"); + +test { + testing.refAllDeclsRecursive(@This()); } -- cgit 1.4.1-2-gfad0