|
| 1 | +"""A module for operations on test CKKS evaluation kernels including. |
| 2 | +
|
| 3 | +- ModAdd |
| 4 | +- HEAdd |
| 5 | +- HESub |
| 6 | +- HEMul |
| 7 | +- HERotate |
| 8 | +""" |
| 9 | + |
| 10 | +from concurrent import futures |
| 11 | +from typing import Any, Callable |
| 12 | + |
| 13 | +import jax |
| 14 | +import jax.numpy as jnp |
| 15 | +from jaxite.jaxite_word import add |
| 16 | + |
| 17 | +from absl.testing import absltest |
| 18 | +from absl.testing import parameterized |
| 19 | + |
| 20 | + |
| 21 | +ProcessPoolExecutor = futures.ProcessPoolExecutor |
| 22 | + |
| 23 | +jax.config.update("jax_enable_x64", True) |
| 24 | +jax.config.update("jax_traceback_filtering", "off") |
| 25 | + |
| 26 | + |
| 27 | +class CKKSEvalKernelsTest(parameterized.TestCase): |
| 28 | + """A base class for running bootstrap tests.""" |
| 29 | + |
| 30 | + def __init__(self, *args, **kwargs): |
| 31 | + super(CKKSEvalKernelsTest, self).__init__(*args, **kwargs) |
| 32 | + self.debug = False # dsiable it from printing the test input values |
| 33 | + self.modulus_element_0_tower_0 = 1152921504606748673 |
| 34 | + self.modulus_element_0_tower_1 = 268664833 |
| 35 | + self.modulus_element_0_tower_2 = 557057 |
| 36 | + self.random_key = jax.random.key(0) |
| 37 | + |
| 38 | + def random(self, shape, modulus_list, dtype=jnp.int32): |
| 39 | + assert len(modulus_list) == shape[1] |
| 40 | + |
| 41 | + return jnp.concatenate( |
| 42 | + [ |
| 43 | + jax.random.randint( |
| 44 | + self.random_key, |
| 45 | + shape=(shape[0], 1, shape[2]), |
| 46 | + minval=0, |
| 47 | + maxval=bound, |
| 48 | + dtype=dtype, |
| 49 | + ) |
| 50 | + for bound in modulus_list |
| 51 | + ], |
| 52 | + axis=1, |
| 53 | + ) |
| 54 | + |
| 55 | + @parameterized.named_parameters( |
| 56 | + dict( |
| 57 | + testcase_name="jax_add", |
| 58 | + test_target=add.jax_add, |
| 59 | + modulus_list=[1152921504606748673, 268664833, 557057], |
| 60 | + shape=(2, 3, 16384), # number of elements, number of towers, degree |
| 61 | + ), |
| 62 | + dict( |
| 63 | + testcase_name="vmap_add", |
| 64 | + test_target=add.vmap_add, |
| 65 | + modulus_list=[1152921504606748673, 268664833, 557057], |
| 66 | + shape=(2, 3, 16384), # number of elements, number of towers, degree |
| 67 | + ), |
| 68 | + ) |
| 69 | + def test_add( |
| 70 | + self, |
| 71 | + test_target: Callable[[Any, Any, Any], Any], |
| 72 | + modulus_list=jax.Array, |
| 73 | + shape=tuple[int, int, int], |
| 74 | + ): |
| 75 | + """This function tests the add function using Python native integer data type with arbitrary precision. |
| 76 | +
|
| 77 | + This test finishes in 1.05 second. |
| 78 | +
|
| 79 | + Args: |
| 80 | + test_target: The function to test. |
| 81 | + modulus_list: A jax.Array of integers. |
| 82 | + shape: A tuple of integers representing the shape of the input arrays. |
| 83 | + """ |
| 84 | + # Only test a single element to save comparison time, |
| 85 | + # Correctness-wise, it's sufficient for add. |
| 86 | + value_a = self.random(shape, modulus_list, dtype=jnp.uint64) |
| 87 | + value_b = self.random(shape, modulus_list, dtype=jnp.uint64) |
| 88 | + assert value_a.shape == shape |
| 89 | + assert value_b.shape == shape |
| 90 | + result_a_plus_b = [] |
| 91 | + for element_id in range(value_a.shape[0]): |
| 92 | + result_a_plus_b_one_element = [] |
| 93 | + for tower_id in range(value_a.shape[1]): |
| 94 | + add_res = int(value_b[element_id, tower_id, 0]) + int( |
| 95 | + value_a[element_id, tower_id, 0] |
| 96 | + ) |
| 97 | + if add_res > modulus_list[tower_id]: |
| 98 | + add_res = add_res - modulus_list[tower_id] |
| 99 | + result_a_plus_b_one_element.append(add_res) |
| 100 | + result_a_plus_b.append(result_a_plus_b_one_element) |
| 101 | + result_a_plus_b = jnp.array(result_a_plus_b, dtype=jnp.uint64) |
| 102 | + modulus_list = jnp.array(modulus_list, dtype=jnp.uint64) |
| 103 | + result = test_target(value_a, value_b, modulus_list) |
| 104 | + self.assertEqual(result[:, :, 0].all(), result_a_plus_b.all()) |
| 105 | + |
| 106 | + |
| 107 | +if __name__ == "__main__": |
| 108 | + absltest.main() |
0 commit comments