2727_JAX_PLATFORM_PROXY = "proxy"
2828_JAX_BACKEND_TARGET_KEY = "jax_backend_target"
2929_JAX_BACKEND_TARGET_HOSTNAME = "grpc://localhost"
30+ _DEFAULT_PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest"
3031
3132_logger = logging .getLogger (__name__ )
3233
@@ -36,6 +37,7 @@ def _deploy_pathways_proxy_server(
3637 proxy_job_name : str ,
3738 expected_instances : Mapping [Any , Any ],
3839 gcs_scratch_location : str ,
40+ proxy_server_image : str ,
3941) -> None :
4042 """Deploys the Pathways proxy pods to the GKE cluster.
4143
@@ -45,6 +47,7 @@ def _deploy_pathways_proxy_server(
4547 expected_instances: A dictionary mapping instance types to the number of
4648 instances.
4749 gcs_scratch_location: The Google Cloud Storage location to use.
50+ proxy_server_image: The image to use for the proxy server.
4851
4952 Raises:
5053 subprocess.CalledProcessError: If the kubectl command fails.
@@ -70,6 +73,7 @@ def _deploy_pathways_proxy_server(
7073 PATHWAYS_HEAD_PORT = pathways_head_port ,
7174 EXPECTED_INSTANCES = instances_str ,
7275 GCS_SCRATCH_LOCATION = gcs_scratch_location ,
76+ PROXY_SERVER_IMAGE = proxy_server_image ,
7377 )
7478
7579 _logger .info ("Deploying Pathways proxy: %s" , proxy_job_name )
@@ -89,6 +93,8 @@ class _ISCPathways:
8993 pathways_service: The service name and port of the Pathways head pod.
9094 expected_tpu_instances: A dictionary mapping TPU machine types to the number
9195 of instances.
96+ proxy_job_name: The name to use for the deployed proxy.
97+ proxy_server_image: The image to use for the proxy server.
9298 """
9399
94100 def __init__ (
@@ -99,7 +105,8 @@ def __init__(
99105 gcs_bucket : str ,
100106 pathways_service : str ,
101107 expected_tpu_instances : Mapping [Any , Any ],
102- proxy_job_name : str | None ,
108+ proxy_job_name : str ,
109+ proxy_server_image : str ,
103110 ):
104111 """Initializes the TPU manager."""
105112 self .cluster = cluster
@@ -108,13 +115,10 @@ def __init__(
108115 self .bucket = gcs_bucket
109116 self .pathways_service = pathways_service
110117 self .expected_tpu_instances = expected_tpu_instances
111- suffix = "" .join (
112- random .choices (string .ascii_lowercase + string .digits , k = 5 )
113- )
114- user = os .environ .get ("USER" , "user" )
115- self ._proxy_job_name = proxy_job_name or f"isc-proxy-{ user } -{ suffix } "
118+ self ._proxy_job_name = proxy_job_name
116119 self ._port_forward_process = None
117120 self ._proxy_port = None
121+ self .proxy_server_image = proxy_server_image
118122
119123 def __repr__ (self ):
120124 return (
@@ -133,6 +137,7 @@ def __enter__(self):
133137 proxy_job_name = self ._proxy_job_name ,
134138 expected_instances = self .expected_tpu_instances ,
135139 gcs_scratch_location = self .bucket ,
140+ proxy_server_image = self .proxy_server_image ,
136141 )
137142 # Print a link to Cloud Logging
138143 cloud_logging_link = gke_utils .get_log_link (
@@ -189,13 +194,15 @@ def _cleanup(self):
189194
190195@contextlib .contextmanager
191196def connect (
192- * , cluster : str ,
197+ * ,
198+ cluster : str ,
193199 project : str ,
194200 region : str ,
195201 gcs_bucket : str ,
196202 pathways_service : str ,
197203 expected_tpu_instances : Mapping [str , int ],
198204 proxy_job_name : str | None = None ,
205+ proxy_server_image : str = _DEFAULT_PROXY_IMAGE ,
199206) -> Iterator ["_ISCPathways" ]:
200207 """Connects to a Pathways server if the cluster exists. If not, creates it.
201208
@@ -209,17 +216,26 @@ def connect(
209216 of instances. For example: {"tpuv6e:2x2": 2}
210217 proxy_job_name: The name to use for the deployed proxy. If not provided, a
211218 random name will be generated.
219+ proxy_server_image: The proxy server image to use. If not provided, a
220+ default will be used.
212221
213222 Yields:
214223 The Pathways manager.
215224 """
216225 _logger .info ("Validating Pathways service and TPU instances..." )
217226 validators .validate_pathways_service (pathways_service )
218227 validators .validate_tpu_instances (expected_tpu_instances )
228+ validators .validate_proxy_server_image (proxy_server_image )
219229 _logger .info ("Validation complete." )
220230 gke_utils .fetch_cluster_credentials (
221231 cluster_name = cluster , project_id = project , location = region
222232 )
233+ proxy_job_name = (
234+ proxy_job_name or f"isc-proxy-{ os .environ .get ('USER' , 'user' )} -{ '' .join (
235+ random .choices (string .ascii_lowercase + string .digits , k = 5 )
236+ )} "
237+ )
238+
223239 _logger .info ("Starting ISCPathways context." )
224240 with _ISCPathways (
225241 cluster = cluster ,
@@ -229,5 +245,6 @@ def connect(
229245 pathways_service = pathways_service ,
230246 expected_tpu_instances = expected_tpu_instances ,
231247 proxy_job_name = proxy_job_name ,
248+ proxy_server_image = proxy_server_image ,
232249 ) as t :
233250 yield t
0 commit comments