Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions gemma/gm/nn/vision/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@

from __future__ import annotations
from collections.abc import Sequence
import io

import einops
from etils import epath
import jax
from jax import numpy as jnp
from kauldron import typing
import numpy as np
from PIL import Image
import tensorflow as tf

# Three-element constants (one value per channel, 127.5 each).
# NOTE(review): presumably the per-channel mean/std used to normalize 8-bit
# pixel values to roughly [-1, 1]; the usage is not visible here — confirm.
_IMAGE_MEAN = (127.5,) * 3
_IMAGE_STD = (127.5,) * 3
Expand Down Expand Up @@ -69,11 +70,13 @@ def pre_process_image(
Returns:
The pre-processed image.
"""
# all inputs are expected to have been jpeg compressed.
# TODO(eyvinec): we should remove tf dependency.
image = jnp.asarray(
tf.image.decode_jpeg(tf.io.encode_jpeg(image), channels=3)
)
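# All inputs are expected to have been JPEG-compressed. Simulate the
# lossy round-trip so pixel values match what the model saw during training.
# NOTE(review): Pillow saves JPEGs at quality=75 by default, whereas
# tf.io.encode_jpeg defaults to quality=95; confirm which quality (and chroma
# subsampling) setting reproduces the training-time preprocessing.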
image_uint8 = np.asarray(image, dtype=np.uint8)
buf = io.BytesIO()
Image.fromarray(image_uint8).save(buf, format="JPEG")
buf.seek(0)
image = jnp.asarray(np.array(Image.open(buf).convert("RGB")))
image = jax.image.resize(
image,
shape=(image_height, image_width, 3),
Expand Down