diff --git a/gemma/gm/nn/vision/_image.py b/gemma/gm/nn/vision/_image.py index ce4182e0..ebbf39dc 100644 --- a/gemma/gm/nn/vision/_image.py +++ b/gemma/gm/nn/vision/_image.py @@ -16,6 +16,8 @@ from __future__ import annotations from collections.abc import Sequence +import io + import einops from etils import epath import jax @@ -23,7 +25,6 @@ from kauldron import typing import numpy as np from PIL import Image -import tensorflow as tf _IMAGE_MEAN = (127.5,) * 3 _IMAGE_STD = (127.5,) * 3 @@ -69,11 +70,13 @@ def pre_process_image( Returns: The pre-processed image. """ - # all inputs are expected to have been jpeg compressed. - # TODO(eyvinec): we should remove tf dependency. - image = jnp.asarray( - tf.image.decode_jpeg(tf.io.encode_jpeg(image), channels=3) - ) + # All inputs are expected to have been JPEG-compressed. Simulate the + # lossy round-trip so pixel values match what the model saw during training. + image_uint8 = np.asarray(image, dtype=np.uint8) + buf = io.BytesIO() + Image.fromarray(image_uint8).save(buf, format="JPEG") + buf.seek(0) + image = jnp.asarray(np.array(Image.open(buf).convert("RGB"))) image = jax.image.resize( image, shape=(image_height, image_width, 3),