Skip to content

Commit eccfdb8

Browse files
JianmingTONGcopybara-github
authored andcommitted
Finish jaxiteword emitter
PiperOrigin-RevId: 727553349
1 parent 43e7ae0 commit eccfdb8

3 files changed

Lines changed: 182 additions & 0 deletions

File tree

BUILD

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,22 @@ tpu_test(
178178
],
179179
)
180180

181+
tpu_test(
182+
name = "add_test",
183+
size = "large",
184+
timeout = "eternal",
185+
srcs = ["jaxite_word/add_test.py"],
186+
shard_count = 3,
187+
deps = [
188+
":jaxite",
189+
"@com_google_absl_py//absl/testing:absltest",
190+
"@com_google_absl_py//absl/testing:parameterized",
191+
"@jaxite_deps_jax//:pkg",
192+
"@jaxite_deps_jaxlib//:pkg",
193+
"@jaxite_deps_numpy//:pkg",
194+
],
195+
)
196+
181197
cpu_gpu_tpu_test(
182198
name = "decomposition_test",
183199
size = "small",

jaxite_word/add.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""TPU kernels for Evaluation of the CKKS algorithm."""
2+
3+
import jax
4+
import jax.numpy as jnp
5+
6+
7+
def jax_add(value_a: jax.Array, value_b: jax.Array, modulus_list: jax.Array):
8+
"""This function processes all degree of the two input polynomials in parallel using multi-trheading.
9+
10+
Assuming the input data type is jax array.
11+
12+
Args:
13+
value_a: the first operand of the addition.
14+
value_b: the second operand of the addition.
15+
modulus_list: the list of moduli for each degree.
16+
17+
Returns:
18+
The result of the addition.
19+
"""
20+
num_elements, _, degree = value_a.shape
21+
modulus_broadcast = jnp.tile(
22+
modulus_list[None, :, None], (num_elements, 1, degree)
23+
)
24+
result = value_a + value_b
25+
return jnp.where(
26+
result > modulus_broadcast, result - modulus_broadcast, result
27+
) # jnp.mod(value_a + value_b, modulus_broadcast)
28+
29+
30+
def vmap_add(
31+
value_a: jax.Array, value_b: jax.Array, modulus_list: jax.Array
32+
):
33+
"""This function processes all degree of the two input polynomials in SIMD using jax.vmap.
34+
35+
Assuming the input data type is jax array.
36+
37+
Args:
38+
value_a: the first operand of the addition.
39+
value_b: the second operand of the addition.
40+
modulus_list: the list of moduli for each degree.
41+
42+
Returns:
43+
The result of the addition.
44+
"""
45+
num_elements, num_towers, degree = value_a.shape
46+
#ToDo: expand api into four dimensions array with num_ciphertexts, num_towers, degree, num_elements
47+
modulus_broadcast = jnp.tile(
48+
modulus_list[None, :, None], (num_elements, 1, degree)
49+
)
50+
51+
def chunk_wise_add(value_a, value_b):
52+
return value_a + value_b
53+
54+
def chunk_wise_subtract(value_a, value_b):
55+
return jnp.where(value_a > value_b, value_a - value_b, value_a)
56+
57+
result = jax.vmap(chunk_wise_add)(value_a, value_b)
58+
return jax.vmap(chunk_wise_subtract)(result, modulus_broadcast)

jaxite_word/add_test.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""A module for operations on test CKKS evaluation kernels including.
2+
3+
- ModAdd
4+
- HEAdd
5+
- HESub
6+
- HEMul
7+
- HERotate
8+
"""
9+
10+
from concurrent import futures
11+
from typing import Any, Callable
12+
13+
import jax
14+
import jax.numpy as jnp
15+
from jaxite.jaxite_word import add
16+
17+
from absl.testing import absltest
18+
from absl.testing import parameterized
19+
20+
21+
ProcessPoolExecutor = futures.ProcessPoolExecutor
22+
23+
jax.config.update("jax_enable_x64", True)
24+
jax.config.update("jax_traceback_filtering", "off")
25+
26+
27+
class CKKSEvalKernelsTest(parameterized.TestCase):
28+
"""A base class for running bootstrap tests."""
29+
30+
def __init__(self, *args, **kwargs):
31+
super(CKKSEvalKernelsTest, self).__init__(*args, **kwargs)
32+
self.debug = False # dsiable it from printing the test input values
33+
self.modulus_element_0_tower_0 = 1152921504606748673
34+
self.modulus_element_0_tower_1 = 268664833
35+
self.modulus_element_0_tower_2 = 557057
36+
self.random_key = jax.random.key(0)
37+
38+
def random(self, shape, modulus_list, dtype=jnp.int32):
39+
assert len(modulus_list) == shape[1]
40+
41+
return jnp.concatenate(
42+
[
43+
jax.random.randint(
44+
self.random_key,
45+
shape=(shape[0], 1, shape[2]),
46+
minval=0,
47+
maxval=bound,
48+
dtype=dtype,
49+
)
50+
for bound in modulus_list
51+
],
52+
axis=1,
53+
)
54+
55+
@parameterized.named_parameters(
56+
dict(
57+
testcase_name="jax_add",
58+
test_target=add.jax_add,
59+
modulus_list=[1152921504606748673, 268664833, 557057],
60+
shape=(2, 3, 16384), # number of elements, number of towers, degree
61+
),
62+
dict(
63+
testcase_name="vmap_add",
64+
test_target=add.vmap_add,
65+
modulus_list=[1152921504606748673, 268664833, 557057],
66+
shape=(2, 3, 16384), # number of elements, number of towers, degree
67+
),
68+
)
69+
def test_add(
70+
self,
71+
test_target: Callable[[Any, Any, Any], Any],
72+
modulus_list=jax.Array,
73+
shape=tuple[int, int, int],
74+
):
75+
"""This function tests the add function using Python native integer data type with arbitrary precision.
76+
77+
This test finishes in 1.05 second.
78+
79+
Args:
80+
test_target: The function to test.
81+
modulus_list: A jax.Array of integers.
82+
shape: A tuple of integers representing the shape of the input arrays.
83+
"""
84+
# Only test a single element to save comparison time,
85+
# Correctness-wise, it's sufficient for add.
86+
value_a = self.random(shape, modulus_list, dtype=jnp.uint64)
87+
value_b = self.random(shape, modulus_list, dtype=jnp.uint64)
88+
assert value_a.shape == shape
89+
assert value_b.shape == shape
90+
result_a_plus_b = []
91+
for element_id in range(value_a.shape[0]):
92+
result_a_plus_b_one_element = []
93+
for tower_id in range(value_a.shape[1]):
94+
add_res = int(value_b[element_id, tower_id, 0]) + int(
95+
value_a[element_id, tower_id, 0]
96+
)
97+
if add_res > modulus_list[tower_id]:
98+
add_res = add_res - modulus_list[tower_id]
99+
result_a_plus_b_one_element.append(add_res)
100+
result_a_plus_b.append(result_a_plus_b_one_element)
101+
result_a_plus_b = jnp.array(result_a_plus_b, dtype=jnp.uint64)
102+
modulus_list = jnp.array(modulus_list, dtype=jnp.uint64)
103+
result = test_target(value_a, value_b, modulus_list)
104+
self.assertEqual(result[:, :, 0].all(), result_a_plus_b.all())
105+
106+
107+
if __name__ == "__main__":
108+
absltest.main()

0 commit comments

Comments
 (0)