Skip to content

Commit e949b0d

Browse files
committed
Refactor and cleanup restartthinner
1 parent 4ac49d5 commit e949b0d

2 files changed

Lines changed: 255 additions & 112 deletions

File tree

src/subscript/restartthinner/restartthinner.py

Lines changed: 131 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,20 @@
22

33
import argparse
44
import datetime
5-
import glob
6-
import os
5+
import logging
76
import shutil
8-
import sys
7+
import subprocess
98
import tempfile
9+
from contextlib import chdir
1010
from pathlib import Path
1111

12-
import numpy
13-
import pandas
12+
import numpy as np
13+
import pandas as pd
1414
from resdata.resfile import ResdataFile
1515

16-
from subscript import __version__
16+
from subscript import __version__, getLogger
17+
18+
logger = getLogger(__name__)
1719

1820
DESCRIPTION = """
1921
Slice a subset of restart-dates from an E100 Restart file (UNRST)
@@ -28,97 +30,100 @@
2830

2931

3032
def find_resdata_app(toolname: str) -> str:
31-
"""Locate path of apps in resdata.
32-
33-
These have varying suffixes due through the history of resdata Makefiles.
33+
"""Locate path of resdata apps, trying common suffixes (.x, .c.x, .cpp.x).
3434
35-
Depending on resdata-version, it has the .x or the .c.x suffix
36-
We prefer .x.
35+
Args:
36+
toolname: Base name of the tool (e.g., 'rd_unpack')
3737
3838
Returns:
39-
String with path if found.
39+
Full path to the executable.
4040
4141
Raises:
42-
IOError: if tool can't be found
42+
OSError: If tool cannot be found in PATH.
4343
"""
44-
extensions = [".x", ".c.x", ".cpp.x", ""] # Order matters.
45-
candidates = [toolname + extension for extension in extensions]
46-
for candidate in candidates:
47-
for path in os.environ["PATH"].split(os.pathsep):
48-
candidatepath = Path(path) / candidate
49-
if candidatepath.exists():
50-
return str(candidatepath)
51-
raise OSError(toolname + " not found in path, PATH=" + str(os.environ["PATH"]))
52-
53-
54-
def date_slicer(slicedates: list, restartdates: list, restartindices: list) -> dict:
55-
"""Make a dict that maps a chosen restart date to a report index"""
56-
slicedatemap = {}
44+
for ext in [".x", ".c.x", ".cpp.x", ""]: # Order matters.
45+
if path := shutil.which(toolname + ext):
46+
return path
47+
raise OSError(f"{toolname} not found in PATH")
48+
49+
50+
def date_slicer(
51+
slicedates: list[pd.Timestamp],
52+
restartdates: list[datetime.datetime],
53+
restartindices: list[int],
54+
) -> list[int]:
55+
"""Make a list of report indices that match the input slicedates."""
56+
slicedatelist = []
5757
for slicedate in slicedates:
58-
daydistances = [
59-
abs((pandas.Timestamp(slicedate) - x).days) for x in restartdates
60-
]
61-
slicedatemap[slicedate] = restartindices[daydistances.index(min(daydistances))]
62-
return slicedatemap
58+
daydistances = [abs((pd.Timestamp(slicedate) - x).days) for x in restartdates]
59+
slicedatelist.append(restartindices[daydistances.index(min(daydistances))])
60+
return slicedatelist
61+
62+
63+
def rd_repacker(rstfilename: str, slicerstindices: list[int], quiet: bool) -> None:
64+
"""Repack a UNRST file keeping only selected restart indices.
6365
66+
Uses rd_unpack and rd_pack utilities from resdata to unpack the UNRST file,
67+
remove unwanted dates, and repack into a new UNRST file.
6468
65-
def rd_repacker(rstfilename: str, slicerstindices: list, quiet: bool) -> None:
69+
Args:
70+
rstfilename: Path to the UNRST file.
71+
slicerstindices: List of restart indices to keep.
72+
quiet: If True, suppress subprocess output.
73+
74+
Raises:
75+
OSError: If rd_unpack or rd_pack tools are not found.
6676
"""
67-
Wrapper for ecl_unpack.x and ecl_pack.x utilities. These
68-
utilities are from resdata.
77+
rd_unpack = find_resdata_app("rd_unpack")
78+
rd_pack = find_resdata_app("rd_pack")
79+
80+
rstpath = Path(rstfilename)
81+
rstdir = rstpath.parent or Path(".")
82+
rstname = rstpath.name
83+
84+
with chdir(rstdir):
85+
tempdir = Path(tempfile.mkdtemp(dir="."))
86+
try:
87+
# Move UNRST into temp directory and work there
88+
shutil.move(rstname, tempdir / rstname)
89+
90+
with chdir(tempdir):
91+
subprocess.run(
92+
[rd_unpack, rstname],
93+
stdout=subprocess.DEVNULL if quiet else None,
94+
check=True,
95+
)
96+
97+
for file in Path(".").glob("*.X*"):
98+
index = int(file.suffix.lstrip(".X"))
99+
if index not in slicerstindices:
100+
file.unlink()
101+
102+
remaining_files = sorted(Path(".").glob("*.X*"))
103+
subprocess.run(
104+
[rd_pack, *[str(f) for f in remaining_files]],
105+
stdout=subprocess.DEVNULL if quiet else None,
106+
check=True,
107+
)
108+
109+
# Move result back up
110+
shutil.move(rstname, Path("..") / rstname)
111+
finally:
112+
shutil.rmtree(tempdir)
113+
114+
115+
def get_restart_indices(rstfilename: str) -> list[int]:
116+
"""Extract a list of restart indices for a filename.
117+
118+
Args:
119+
rstfilename: Path to the UNRST file.
69120
70-
First unpacking a UNRST file, then deleting dates the dont't want, then
71-
pack the remainding files into a new UNRST file
121+
Returns:
122+
List of restart report indices.
72123
73-
This function will change working directory to the
74-
location of the UNRST file, dump temporary files in there, and
75-
modify the original filename.
124+
Raises:
125+
FileNotFoundError: If the file does not exist.
76126
"""
77-
out = " >/dev/null" if quiet else ""
78-
# Error early if resdata tools are not available
79-
try:
80-
find_resdata_app("rd_unpack")
81-
find_resdata_app("rd_pack")
82-
except OSError:
83-
sys.exit(
84-
"ERROR: rd_unpack.x and/or rd_pack.x not found.\n"
85-
"These tools are required and must be installed separately"
86-
)
87-
88-
# Take special care if the UNRST file we get in is not in current directory
89-
cwd = os.getcwd()
90-
rstfilepath = Path(rstfilename).parent
91-
tempdir = None
92-
93-
try:
94-
os.chdir(Path(rstfilename).parent)
95-
tempdir = tempfile.mkdtemp(dir=".")
96-
os.rename(
97-
os.path.basename(rstfilename),
98-
os.path.join(tempdir, os.path.basename(rstfilename)),
99-
)
100-
os.chdir(tempdir)
101-
os.system(
102-
find_resdata_app("rd_unpack") + " " + os.path.basename(rstfilename) + out
103-
)
104-
unpackedfiles = glob.glob("*.X*")
105-
for file in unpackedfiles:
106-
if int(file.split(".X")[1]) not in slicerstindices:
107-
os.remove(file)
108-
os.system(find_resdata_app("rd_pack") + " *.X*" + out)
109-
# We are inside the tmp directory, move file one step up:
110-
os.rename(
111-
os.path.join(os.getcwd(), os.path.basename(rstfilename)),
112-
os.path.join(os.getcwd(), "../", os.path.basename(rstfilename)),
113-
)
114-
finally:
115-
os.chdir(cwd)
116-
if tempdir is not None:
117-
shutil.rmtree(rstfilepath / tempdir)
118-
119-
120-
def get_restart_indices(rstfilename: str) -> list:
121-
"""Extract a list of RST indices for a filename"""
122127
if Path(rstfilename).exists():
123128
# This function segfaults if file does not exist
124129
return ResdataFile.file_report_list(str(rstfilename))
@@ -132,8 +137,14 @@ def restartthinner(
132137
dryrun: bool = True,
133138
keep: bool = False,
134139
) -> None:
135-
"""
136-
Thin an existing UNRST file to selected number of restarts.
140+
"""Thin an existing UNRST file to selected number of restarts.
141+
142+
Args:
143+
filename: Path to the UNRST file.
144+
numberofslices: Number of restart dates to keep.
145+
quiet: If True, suppress informational output.
146+
dryrun: If True, only show what would be done without modifying files.
147+
keep: If True, keep original file with .orig suffix.
137148
"""
138149
rst = ResdataFile(filename)
139150
restart_indices = get_restart_indices(filename)
@@ -142,41 +153,39 @@ def restartthinner(
142153
]
143154

144155
if numberofslices > 1:
145-
slicedates = pandas.DatetimeIndex(
146-
numpy.linspace(
147-
pandas.Timestamp(restart_dates[0]).value,
148-
pandas.Timestamp(restart_dates[-1]).value,
156+
slicedates = pd.DatetimeIndex(
157+
np.linspace(
158+
pd.Timestamp(restart_dates[0]).value,
159+
pd.Timestamp(restart_dates[-1]).value,
149160
int(numberofslices),
150161
)
151162
).to_list()
152163
else:
153164
slicedates = [restart_dates[-1]] # Only return last date if only one is wanted
154165

155-
slicerstindices = list(
156-
date_slicer(slicedates, restart_dates, restart_indices).values()
157-
)
158-
slicerstindices.sort()
159-
slicerstindices = list(set(slicerstindices)) # uniquify
166+
slicerstindices = date_slicer(slicedates, restart_dates, restart_indices)
167+
slicerstindices = sorted(set(slicerstindices)) # uniquify
160168

161169
if not quiet:
162-
print("Selected restarts:")
163-
print("-----------------------")
170+
logger.info("Selected restarts:")
171+
logger.info("-----------------------")
164172
for idx, rstidx in enumerate(restart_indices):
165173
slicepresent = "X" if rstidx in slicerstindices else ""
166-
print(
167-
f"{rstidx:4d} "
168-
f"{datetime.date.strftime(restart_dates[idx], '%Y-%m-%d')} "
169-
f"{slicepresent}"
174+
logger.info(
175+
"%4d %s %s",
176+
rstidx,
177+
datetime.date.strftime(restart_dates[idx], "%Y-%m-%d"),
178+
slicepresent,
170179
)
171-
print("-----------------------")
180+
logger.info("-----------------------")
181+
172182
if not dryrun:
173183
if keep:
174184
backupname = filename + ".orig"
175-
if not quiet:
176-
print(f"Info: Backing up {filename} to {backupname}")
185+
logger.info("Backing up %s to %s", filename, backupname)
177186
shutil.copyfile(filename, backupname)
178187
rd_repacker(filename, slicerstindices, quiet)
179-
print(f"Written to {filename}")
188+
logger.info("Written to %s", filename)
180189

181190

182191
def get_parser() -> argparse.ArgumentParser:
@@ -186,7 +195,11 @@ def get_parser() -> argparse.ArgumentParser:
186195
)
187196
parser.add_argument("UNRST", help="Name of UNRST file")
188197
parser.add_argument(
189-
"-n", "--restarts", type=int, help="Number of restart dates wanted", default=0
198+
"-n",
199+
"--restarts",
200+
type=int,
201+
help="Number of restart dates wanted",
202+
required=True,
190203
)
191204
parser.add_argument(
192205
"-d",
@@ -218,13 +231,19 @@ def get_parser() -> argparse.ArgumentParser:
218231

219232

220233
def main() -> None:
221-
"""Endpoint for command line script"""
234+
"""Endpoint for command line script."""
222235
parser = get_parser()
223236
args = parser.parse_args()
237+
224238
if args.restarts <= 0:
225-
print("ERROR: Number of restarts must be a positive number")
226-
sys.exit(1)
227-
if args.UNRST.endswith("DATA"):
228-
print("ERROR: Provide the UNRST file, not the DATA file")
229-
sys.exit(1)
239+
parser.error("Number of restarts must be a positive number")
240+
if args.UNRST.endswith(".DATA"):
241+
parser.error("Provide the UNRST file, not the DATA file")
242+
if args.quiet:
243+
logger.setLevel(logging.WARNING)
244+
230245
restartthinner(args.UNRST, args.restarts, args.quiet, args.dryrun, args.keep)
246+
247+
248+
if __name__ == "__main__":
249+
main()

0 commit comments

Comments
 (0)