The Gaudi Framework  master (69a68366)
Loading...
Searching...
No Matches
SubprocessBaseTest.py
Go to the documentation of this file.
11import os
12import re
13import select
14import shutil
15import signal
16import subprocess
17import threading
18from datetime import datetime
19from pathlib import Path
20from string import Template
21from typing import Callable, Dict, List, Optional, Union
22
23import pytest
24
26 ExceededStreamError,
27 FixtureResult,
28 ProcessTimeoutError,
29)
30from GaudiTesting.utils import file_path_for_class, kill_tree, which
31
# Maximum number of bytes read from each of stdout and stderr before
# run_program stops reading and reports ExceededStreamError.
# Defaults to 100 MiB; override with the GAUDI_TEST_STDOUT_LIMIT environment variable.
STDOUT_LIMIT = int(os.environ.get("GAUDI_TEST_STDOUT_LIMIT", 100 * 1024 * 1024))
33
34
    """
    A base class for running and managing subprocess executions within a test framework.
    It provides mechanisms for setting up the environment, preparing commands,
    and handling subprocess output and errors.

    Subclasses configure a run by overriding the class attributes below.
    """

    # Program followed by its arguments; the program is located with
    # _determine_program and arguments not starting with "-" go through resolve_path.
    command: Optional[List[str]] = None
    # NOTE(review): not used in the methods visible here; presumably the path of a
    # reference file to compare output against — confirm with the fixtures.
    reference: Optional[str] = None
    # "NAME=value" entries applied on top of a copy of os.environ; "$VAR"/"${VAR}"
    # in the value are expanded against the environment built so far.
    environment: Optional[List[str]] = None
    # Seconds to wait for the output reader before collecting a stack trace
    # and terminating the process.
    timeout: int = 600
    # Exit code expected by test_returncode.
    returncode: int = 0
    # Extra keyword arguments for subprocess.Popen; the "cwd" entry is also used
    # as the working directory to create and to report.
    # NOTE: this dict is a shared class-level default — subclasses should assign
    # their own dict rather than mutate this one.
    popen_kwargs: Dict = {}
    # By default record environment variables that are likely to affect job execution
    env_to_record: List[str] = [
        r"^(.*PATH|BINARY_TAG|CMTCONFIG|PWD|TMPDIR|DISPLAY|LC_[A-Z]+|LANGUAGE|.*ROOT|.*OPTS|GAUDI_ENV_TO_RECORD(_\d+)?)$"
    ]
52
53 @property
54 def cwd(self) -> Optional[Path]:
55 cwd = self.popen_kwargs.get("cwd")
56 return Path(cwd) if cwd else None
57
58 @classmethod
59 def resolve_path(cls, path: Union[Path, str]) -> str:
60 """
61 Resolve the given path to an absolute path,
62 expanding environment variables.
63 If path looks relative and does not point to anything
64 it is not modified.
65 """
66 if isinstance(path, Path):
67 path = str(path)
68 path = os.path.expandvars(path)
69
70 # handle the special case "path/to/file:some_suffix"
71 suffix = ""
72 if ":" in path:
73 path, suffix = path.rsplit(":", 1)
74 suffix = f":{suffix}"
75
76 if not os.path.isabs(path):
77 base_dir = file_path_for_class(cls).parent
78 possible_path = str((base_dir / path).resolve())
79 if os.path.exists(possible_path):
80 path = possible_path
81 return path + suffix
82
83 @classmethod
84 def update_env(cls, env: Dict[str, str]) -> None:
85 if cls.environment:
86 for item in cls.environment:
87 key, value = item.split("=", 1)
88 env[key] = cls.expand_vars_from(value, env)
89
90 @classmethod
91 def _prepare_environment(cls) -> Dict[str, str]:
92 env = dict(os.environ)
93 cls.update_env(env)
94 return env
95
96 @staticmethod
97 def expand_vars_from(value: str, env: Dict[str, str]) -> str:
98 return Template(value).safe_substitute(env)
99
100 @staticmethod
101 def unset_vars(env: Dict[str, str], vars_to_unset: List[str]) -> None:
102 for var in vars_to_unset:
103 env.pop(var, None)
104
105 @classmethod
106 def _determine_program(cls, prog: str) -> str:
107 if not any(prog.lower().endswith(ext) for ext in [".exe", ".py", ".bat"]):
108 prog += ".exe"
109 return which(prog) or cls.resolve_path(prog)
110
111 @classmethod
112 def _prepare_command(cls, tmp_path=Path()) -> List[str]:
113 """
114 Prepare the command to be executed, resolving paths for each part.
115 """
116 command = [cls._determine_program(cls.command[0])]
117 for part in cls.command[1:]:
118 if not part.startswith("-"): # do not try to expand options
119 command.append(cls.resolve_path(part))
120 else:
121 command.append(part)
122 return command
123
124 @classmethod
125 def _handle_timeout(cls, proc: subprocess.Popen) -> str:
126 """
127 Handle a process timeout by collecting and returning the stack trace.
128 """
129 stack_trace = cls._collect_stack_trace(proc)
130 cls._terminate_process(proc)
131 return stack_trace
132
133 @staticmethod
134 def _collect_stack_trace(proc: subprocess.Popen) -> str:
135 """
136 Collect stack trace from a running process using the available debugger.
137 Uses lldb on macOS and gdb on Linux. Returns a message if no debugger
138 is available.
139 """
140 import sys
141
142 if sys.platform == "darwin":
143 # Use lldb on macOS
144 if not shutil.which("lldb"):
145 return "(lldb not available for stack trace collection)"
146 cmd = [
147 "lldb",
148 "-p",
149 str(proc.pid),
150 "-o",
151 "bt all",
152 "-o",
153 "quit",
154 ]
155 else:
156 # Use gdb on Linux
157 if not shutil.which("gdb"):
158 return "(gdb not available for stack trace collection)"
159 cmd = [
160 "gdb",
161 "--pid",
162 str(proc.pid),
163 "--batch",
164 "--eval-command=thread apply all backtrace",
165 ]
166 debugger = subprocess.Popen(
167 cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
168 )
169 return debugger.communicate()[0].decode("utf-8", errors="backslashreplace")
170
171 @staticmethod
172 def _terminate_process(proc: subprocess.Popen) -> None:
173 kill_tree(proc.pid, signal.SIGTERM)
174 proc.wait(60)
175 if proc.poll() is None:
176 kill_tree(proc.pid, signal.SIGKILL)
177
178 @classmethod
179 def _prepare_execution(cls, tmp_path=None):
180 command = cls._prepare_command(tmp_path=tmp_path)
181 env = cls._prepare_environment()
182 os.makedirs(cls.popen_kwargs.get("cwd", tmp_path), exist_ok=True)
183 return command, env
184
185 @classmethod
186 def run_program(cls, tmp_path=None) -> FixtureResult:
187 """
188 Run the specified program and capture its output.
189 """
190 start_time = datetime.now()
191 command, env = cls._prepare_execution(tmp_path=tmp_path)
192
193 proc = subprocess.Popen(
194 command,
195 stdout=subprocess.PIPE,
196 stderr=subprocess.PIPE,
197 env=env,
198 **cls.popen_kwargs,
199 )
200
201 stdout_chunks, stderr_chunks = [], []
202 stdout = stderr = ""
203 exceeded_stream = stack_trace = run_exception = None
204 streams = {
205 proc.stdout.fileno(): (stdout_chunks, "stdout"),
206 proc.stderr.fileno(): (stderr_chunks, "stderr"),
207 }
208
209 def read_output():
210 nonlocal stdout, stderr, exceeded_stream
211 while not exceeded_stream and proc.poll() is None:
212 readable, _, _ = select.select(streams.keys(), [], [], cls.timeout)
213 for fileno in readable:
214 data = os.read(fileno, 1024)
215 chunks, stream_name = streams[fileno]
216 chunks.append(data)
217 if sum(len(chunk) for chunk in chunks) > STDOUT_LIMIT:
218 exceeded_stream = stream_name
219 break
220
221 stdout = b"".join(stdout_chunks)
222 stderr = b"".join(stderr_chunks)
223
224 thread = threading.Thread(target=read_output)
225 thread.start()
226 thread.join(cls.timeout)
227
228 if thread.is_alive():
229 stack_trace = cls._handle_timeout(proc)
230 run_exception = ProcessTimeoutError("Process timed out", stack_trace)
231 elif exceeded_stream:
232 run_exception = ExceededStreamError(
233 "Stream exceeded size limit", exceeded_stream
234 )
235
236 end_time = datetime.now()
237
238 completed_process = subprocess.CompletedProcess(
239 args=command,
240 returncode=proc.returncode,
241 stdout=stdout,
242 stderr=stderr,
243 )
244
245 # Only record environment variables that match at least one
246 # of the regular expressions in env_to_record.
247 # If env_to_record is empty or None, do not record any env variable
248 env_to_record = list(cls.env_to_record) if cls.env_to_record else []
249 # extend env_to_record with the regular expressions in environment variables
250 # named GAUDI_ENV_TO_RECORD, GAUDI_ENV_TO_RECORD_1, GAUDI_ENV_TO_RECORD_2, ...
251 env_to_record.extend(
252 env[name] for name in env if re.match(r"GAUDI_ENV_TO_RECORD(_\d+)?", name)
253 )
254 env = {
255 key: value
256 for key, value in env.items()
257 if any(re.match(exp, key) for exp in env_to_record)
258 }
259 return FixtureResult(
260 completed_process=completed_process,
261 start_time=start_time,
262 end_time=end_time,
263 run_exception=run_exception,
264 command=cls.command,
265 expanded_command=command,
266 env=env,
267 cwd=cls.popen_kwargs["cwd"],
268 )
269
270 @classmethod
271 def run_program_for_dbg(cls, tmp_path):
272 tmp_path = cls.popen_kwargs.get("cwd", tmp_path)
273 command, env = cls._prepare_execution(tmp_path=tmp_path)
274 print("Running the command: ", command)
275 subprocess.run(
276 command,
277 env=env,
278 cwd=tmp_path,
279 **cls.popen_kwargs,
280 )
281
282 @pytest.mark.do_not_collect_source
284 self,
285 record_property: Callable[[str, str], None],
286 fixture_result: FixtureResult,
287 reference_path: Optional[Path],
288 ) -> None:
289 """
290 Record properties and handle any failures during fixture setup.
291 """
292 for key, value in fixture_result.to_dict().items():
293 if value is not None:
294 record_property(key, value)
295 if reference_path:
296 record_property("reference_file", str(reference_path))
297
298 if fixture_result.run_exception:
299 pytest.fail(f"{fixture_result.run_exception}")
300
301 @pytest.mark.do_not_collect_source
302 def test_returncode(self, returncode: int) -> None:
303 """
304 Test that the return code matches the expected value.
305 """
306 assert returncode == self.returncode
None unset_vars(Dict[str, str] env, List[str] vars_to_unset)
None test_fixture_setup(self, Callable[[str, str], None] record_property, FixtureResult fixture_result, Optional[Path] reference_path)
str expand_vars_from(str value, Dict[str, str] env)