|
7 | 7 | import string |
8 | 8 | import subprocess |
9 | 9 | import sys |
| 10 | +import time |
| 11 | +import random |
10 | 12 | from threading import Timer |
11 | 13 |
|
12 | 14 | import pkg_resources |
@@ -123,7 +125,7 @@ def createCluster(self, |
123 | 125 |
|
124 | 126 | for i in init_actions: |
125 | 127 | if "install_gpu_driver.sh" in i or "horovod.sh" in i or \ |
126 | | - "dask-rapids.sh" in i or "mlvm.sh" in i or \ |
| 128 | + "rapids.sh" in i or "mlvm.sh" in i or \ |
127 | 129 | "spark-rapids.sh" in i: |
128 | 130 | args.append("--no-shielded-secure-boot") |
129 | 131 |
|
@@ -287,10 +289,24 @@ def assert_instance_command(self, |
287 | 289 | AssertionError: if command returned non-0 exit code. |
288 | 290 | """ |
289 | 291 |
|
290 | | - ret_code, stdout, stderr = self.assert_command( |
291 | | - 'gcloud compute ssh {} --zone={} --command="{}"'.format( |
292 | | - instance, self.cluster_zone, cmd), timeout_in_minutes) |
293 | | - return ret_code, stdout, stderr |
| 292 | + retry_count = 5 |
| 293 | + |
| 294 | + ssh_cmd='gcloud compute ssh -q {} --zone={} --command="{}" -- -o ConnectTimeout=60'.format( |
| 295 | + instance, self.cluster_zone, cmd) |
| 296 | + |
| 297 | + while retry_count > 0: |
| 298 | + try: |
| 299 | + ret_code, stdout, stderr = self.assert_command( |
| 300 | + ssh_cmd, timeout_in_minutes ) |
| 301 | + return ret_code, stdout, stderr |
| 302 | + except Exception as e: |
| 303 | + print("An error occurred: ", e) |
| 304 | + retry_count -= 1 |
| 305 | + if retry_count > 0: |
| 306 | + time.sleep( 3 + random.randint(1, 10) ) |
| 307 | + continue |
| 308 | + else: |
| 309 | + raise |
294 | 310 |
|
295 | 311 | def assert_dataproc_job(self, |
296 | 312 | cluster_name, |
|
0 commit comments