Skip to content

Commit 57f4761

Browse files
Pathways-on-Cloud Teamcopybara-github
authored andcommitted
Make Pathways proxy server image user-configurable
The user can optionally pass a custom Pathways proxy server image. This will allow them to use the image corresponding to their head pod. PiperOrigin-RevId: 859207127
1 parent a93a2a6 commit 57f4761

4 files changed

Lines changed: 60 additions & 8 deletions

File tree

pathwaysutils/experimental/shared_pathways_service/isc_pathways.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
_JAX_PLATFORM_PROXY = "proxy"
2828
_JAX_BACKEND_TARGET_KEY = "jax_backend_target"
2929
_JAX_BACKEND_TARGET_HOSTNAME = "grpc://localhost"
30+
_DEFAULT_PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest"
3031

3132
_logger = logging.getLogger(__name__)
3233

@@ -36,6 +37,7 @@ def _deploy_pathways_proxy_server(
3637
proxy_job_name: str,
3738
expected_instances: Mapping[Any, Any],
3839
gcs_scratch_location: str,
40+
proxy_server_image: str,
3941
) -> None:
4042
"""Deploys the Pathways proxy pods to the GKE cluster.
4143
@@ -45,6 +47,7 @@ def _deploy_pathways_proxy_server(
4547
expected_instances: A dictionary mapping instance types to the number of
4648
instances.
4749
gcs_scratch_location: The Google Cloud Storage location to use.
50+
proxy_server_image: The image to use for the proxy server.
4851
4952
Raises:
5053
subprocess.CalledProcessError: If the kubectl command fails.
@@ -70,6 +73,7 @@ def _deploy_pathways_proxy_server(
7073
PATHWAYS_HEAD_PORT=pathways_head_port,
7174
EXPECTED_INSTANCES=instances_str,
7275
GCS_SCRATCH_LOCATION=gcs_scratch_location,
76+
PROXY_SERVER_IMAGE=proxy_server_image,
7377
)
7478

7579
_logger.info("Deploying Pathways proxy: %s", proxy_job_name)
@@ -89,6 +93,8 @@ class _ISCPathways:
8993
pathways_service: The service name and port of the Pathways head pod.
9094
expected_tpu_instances: A dictionary mapping TPU machine types to the number
9195
of instances.
96+
proxy_job_name: The name to use for the deployed proxy.
97+
proxy_server_image: The image to use for the proxy server.
9298
"""
9399

94100
def __init__(
@@ -99,7 +105,8 @@ def __init__(
99105
gcs_bucket: str,
100106
pathways_service: str,
101107
expected_tpu_instances: Mapping[Any, Any],
102-
proxy_job_name: str | None,
108+
proxy_job_name: str,
109+
proxy_server_image: str,
103110
):
104111
"""Initializes the TPU manager."""
105112
self.cluster = cluster
@@ -108,13 +115,10 @@ def __init__(
108115
self.bucket = gcs_bucket
109116
self.pathways_service = pathways_service
110117
self.expected_tpu_instances = expected_tpu_instances
111-
suffix = "".join(
112-
random.choices(string.ascii_lowercase + string.digits, k=5)
113-
)
114-
user = os.environ.get("USER", "user")
115-
self._proxy_job_name = proxy_job_name or f"isc-proxy-{user}-{suffix}"
118+
self._proxy_job_name = proxy_job_name
116119
self._port_forward_process = None
117120
self._proxy_port = None
121+
self.proxy_server_image = proxy_server_image
118122

119123
def __repr__(self):
120124
return (
@@ -133,6 +137,7 @@ def __enter__(self):
133137
proxy_job_name=self._proxy_job_name,
134138
expected_instances=self.expected_tpu_instances,
135139
gcs_scratch_location=self.bucket,
140+
proxy_server_image=self.proxy_server_image,
136141
)
137142
# Print a link to Cloud Logging
138143
cloud_logging_link = gke_utils.get_log_link(
@@ -189,13 +194,15 @@ def _cleanup(self):
189194

190195
@contextlib.contextmanager
191196
def connect(
192-
*, cluster: str,
197+
*,
198+
cluster: str,
193199
project: str,
194200
region: str,
195201
gcs_bucket: str,
196202
pathways_service: str,
197203
expected_tpu_instances: Mapping[str, int],
198204
proxy_job_name: str | None = None,
205+
proxy_server_image: str = _DEFAULT_PROXY_IMAGE,
199206
) -> Iterator["_ISCPathways"]:
200207
"""Connects to a Pathways server if the cluster exists. If not, creates it.
201208
@@ -209,17 +216,26 @@ def connect(
209216
of instances. For example: {"tpuv6e:2x2": 2}
210217
proxy_job_name: The name to use for the deployed proxy. If not provided, a
211218
random name will be generated.
219+
proxy_server_image: The proxy server image to use. If not provided, a
220+
default will be used.
212221
213222
Yields:
214223
The Pathways manager.
215224
"""
216225
_logger.info("Validating Pathways service and TPU instances...")
217226
validators.validate_pathways_service(pathways_service)
218227
validators.validate_tpu_instances(expected_tpu_instances)
228+
validators.validate_proxy_server_image(proxy_server_image)
219229
_logger.info("Validation complete.")
220230
gke_utils.fetch_cluster_credentials(
221231
cluster_name=cluster, project_id=project, location=region
222232
)
233+
proxy_job_name = (
234+
proxy_job_name or f"isc-proxy-{os.environ.get('USER', 'user')}-{''.join(
235+
random.choices(string.ascii_lowercase + string.digits, k=5)
236+
)}"
237+
)
238+
223239
_logger.info("Starting ISCPathways context.")
224240
with _ISCPathways(
225241
cluster=cluster,
@@ -229,5 +245,6 @@ def connect(
229245
pathways_service=pathways_service,
230246
expected_tpu_instances=expected_tpu_instances,
231247
proxy_job_name=proxy_job_name,
248+
proxy_server_image=proxy_server_image,
232249
) as t:
233250
yield t

pathwaysutils/experimental/shared_pathways_service/run_connect_example.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@
2424
"tpu_type", "tpuv6e:2x2", "The TPU machine type and topology."
2525
)
2626
flags.DEFINE_integer("tpu_count", 1, "The number of TPU slices.")
27+
flags.DEFINE_string(
28+
"proxy_job_name",
29+
None,
30+
"The name to use for the GKE job for proxy. If not provided, a random name"
31+
" will be generated.",
32+
)
33+
flags.DEFINE_string(
34+
"proxy_server_image",
35+
None,
36+
"The proxy server image to use. If not provided, a default will be used.",
37+
)
2738

2839
flags.mark_flags_as_required([
2940
"cluster",
@@ -37,13 +48,21 @@
3748
def main(argv: Sequence[str]) -> None:
3849
if len(argv) > 1:
3950
raise app.UsageError("Too many command-line arguments.")
51+
52+
kwargs = {}
53+
if FLAGS.proxy_job_name:
54+
kwargs["proxy_job_name"] = FLAGS.proxy_job_name
55+
if FLAGS.proxy_server_image:
56+
kwargs["proxy_server_image"] = FLAGS.proxy_server_image
57+
4058
with isc_pathways.connect(
4159
cluster=FLAGS.cluster,
4260
project=FLAGS.project,
4361
region=FLAGS.region,
4462
gcs_bucket=FLAGS.gcs_bucket,
4563
pathways_service=FLAGS.pathways_service,
4664
expected_tpu_instances={FLAGS.tpu_type: FLAGS.tpu_count},
65+
**kwargs,
4766
):
4867
orig_matrix = jnp.zeros(5)
4968
result_matrix = orig_matrix + 1

pathwaysutils/experimental/shared_pathways_service/validators.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,19 @@ def validate_tpu_instances(expected_tpu_instances: Mapping[Any, Any]) -> None:
8989

9090
inst = next(iter(expected_tpu_instances.keys()))
9191
_validate_tpu_supported(inst)
92+
93+
94+
def validate_proxy_server_image(proxy_server_image: str) -> None:
95+
"""Validates the proxy server image format."""
96+
if not proxy_server_image or not proxy_server_image.strip():
97+
raise ValueError("Proxy server image cannot be empty.")
98+
if "/" not in proxy_server_image:
99+
raise ValueError(
100+
f"Proxy server image '{proxy_server_image}' must contain '/', "
101+
"separating the registry or namespace from the final image name."
102+
)
103+
if ":" not in proxy_server_image and "@" not in proxy_server_image:
104+
raise ValueError(
105+
f"Proxy server image '{proxy_server_image}' must contain a tag with ':'"
106+
" or a digest with '@'."
107+
)

pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ spec:
1414
automountServiceAccountToken: false
1515
containers:
1616
- name: pathways-proxy
17-
image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:jax-0.8.0@sha256:5296fa0819d8cbdfbcf951ffca2072128255411557240624ff4011522a6a2abe
17+
image: ${PROXY_SERVER_IMAGE}
1818
imagePullPolicy: Always
1919
args:
2020
- --server_port=${PROXY_SERVER_PORT}

0 commit comments

Comments
 (0)