diff --git a/test/TensorTest.cpp b/test/TensorTest.cpp index 49c9294..e740a4f 100644 --- a/test/TensorTest.cpp +++ b/test/TensorTest.cpp @@ -333,5 +333,136 @@ TEST_F(TensorTest, SymNumel) { EXPECT_EQ(sym_numel, tensor.numel()); } +// 测试 all() - 检查所有元素是否为真(非零) +TEST_F(TensorTest, All) { + FileManerger file(GetTestCaseResultFileName()); + file.createFile(); + + // 测试全1张量 - all() 应返回 true + at::Tensor all_ones = at::ones({2, 2}, at::kInt); + bool result1 = all_ones.all().item(); + file << std::to_string(result1) << " "; + + // 测试全0张量 - all() 应返回 false + at::Tensor all_zeros = at::zeros({2, 2}, at::kInt); + bool result2 = all_zeros.all().item(); + file << std::to_string(result2) << " "; + + // 测试混合张量(有0有1)- all() 应返回 false + std::vector data3 = {1, 0, 1, 1}; + at::Tensor mixed = at::from_blob(data3.data(), {2, 2}, at::kInt).clone(); + bool result3 = mixed.all().item(); + file << std::to_string(result3) << " "; + + // 测试全为负数张量 - all() 应返回 true(非零) + std::vector data4 = {-1, -2, -3, -4}; + at::Tensor all_neg = at::from_blob(data4.data(), {2, 2}, at::kInt).clone(); + bool result4 = all_neg.all().item(); + file << std::to_string(result4) << " "; + + file.saveFile(); +} + +// 测试 all(dim, keepdim) - 沿指定维度检查 +TEST_F(TensorTest, AllDim) { + FileManerger file(GetTestCaseResultFileName()); + file.createFile(); + + std::vector data = {1, 0, 1, 1, 1, 1}; + at::Tensor tensor = at::from_blob(data.data(), {2, 3}, at::kInt).clone(); + + // 沿 dim=0 检查 - 每列所有行 + at::Tensor result_dim0 = tensor.all(0, false); + file << std::to_string(result_dim0.sizes()[0]) << " "; + file << std::to_string(result_dim0.sizes()[1]) << " "; + // 第一列有0,应为false;第二列全为1,应为true;第三列有0,应为false + file << std::to_string(result_dim0[0].item()) << " "; + file << std::to_string(result_dim0[1].item()) << " "; + file << std::to_string(result_dim0[2].item()) << " "; + + // 沿 dim=1 检查 - 每行所有列 + at::Tensor result_dim1 = tensor.all(1, false); + file << std::to_string(result_dim1.sizes()[0]) << " "; + // 第一行有0,应为false;第二行全为1,应为true + file << std::to_string(result_dim1[0].item()) << " "; + file << std::to_string(result_dim1[1].item()) << " "; + + // 测试 keepdim=true + at::Tensor result_keepdim = tensor.all(1, true); + file << std::to_string(result_keepdim.sizes()[0]) << " "; + file << std::to_string(result_keepdim.sizes()[1]) << " "; + + file.saveFile(); +} + +// 测试 all(at::OptionalIntArrayRef dim, bool keepdim) +TEST_F(TensorTest, AllOptionalDim) { + FileManerger file(GetTestCaseResultFileName()); + file.createFile(); + + std::vector data = {1, 0, 1, 1, 1, 1}; + at::Tensor tensor = at::from_blob(data.data(), {2, 3}, at::kInt).clone(); + + // 不指定维度 - 检查所有元素 + at::Tensor result_no_dim = tensor.all(c10::nullopt, false); + file << std::to_string(result_no_dim.item()) << " "; + + // 指定单个维度 + at::Tensor result_single_dim = tensor.all({0}, false); + file << std::to_string(result_single_dim[0].item()) << " "; + file << std::to_string(result_single_dim[1].item()) << " "; + file << std::to_string(result_single_dim[2].item()) << " "; + + // 指定多个维度 + at::Tensor result_multi_dim = tensor.all({0, 1}, false); + file << std::to_string(result_multi_dim.item()) << " "; + + file.saveFile(); +} + +// 测试 allclose - 检查两个张量是否接近 +TEST_F(TensorTest, Allclose) { + FileManerger file(GetTestCaseResultFileName()); + file.createFile(); + + // 测试1: 完全相同的张量 - 应返回 true + std::vector data1 = {1.0f, 2.0f, 3.0f}; + at::Tensor t1 = at::from_blob(data1.data(), {3}, at::kFloat).clone(); + at::Tensor t1_copy = at::from_blob(data1.data(), {3}, at::kFloat).clone(); + bool result1 = t1.allclose(t1_copy); + file << std::to_string(result1) << " "; + + // 测试2: 在默认 rtol/atol 范围内的张量 - 应返回 true + std::vector data2 = {1.0f, 2.0f, 3.0f}; + std::vector data2_slight = {1.0f + 1e-6f, 2.0f - 1e-6f, 3.0f}; + at::Tensor t2 = at::from_blob(data2.data(), {3}, at::kFloat).clone(); + at::Tensor t2_slight = + at::from_blob(data2_slight.data(), {3}, at::kFloat).clone(); + bool result2 = t2.allclose(t2_slight); + file << std::to_string(result2) << " "; + + // 测试3: 超出默认容差的张量 - 应返回 false + std::vector data3 = {1.0f, 2.0f, 3.0f}; + std::vector data3_diff = {1.5f, 2.0f, 3.0f}; // 差异 0.5 > 默认 atol + at::Tensor t3 = at::from_blob(data3.data(), {3}, at::kFloat).clone(); + at::Tensor t3_diff = + at::from_blob(data3_diff.data(), {3}, at::kFloat).clone(); + bool result3 = t3.allclose(t3_diff); + file << std::to_string(result3) << " "; + + // 测试4: 使用较大 rtol 的张量 - 应返回 true + bool result4 = t3.allclose(t3_diff, 0.5, 0.1, false); + file << std::to_string(result4) << " "; + + // 测试7: 多维张量 + std::vector data7 = {1.0f, 2.0f, 3.0f, 4.0f}; + at::Tensor t7 = at::from_blob(data7.data(), {2, 2}, at::kFloat).clone(); + at::Tensor t7_copy = at::from_blob(data7.data(), {2, 2}, at::kFloat).clone(); + bool result7 = t7.allclose(t7_copy); + file << std::to_string(result7) << " "; + + file.saveFile(); +} + } // namespace test } // namespace at