The Gaudi Framework  master (181af51f)
Loading...
Searching...
No Matches
SubprocessBaseTest.py
Go to the documentation of this file.
11import os
12import re
13import select
14import signal
15import subprocess
16import threading
17from datetime import datetime
18from pathlib import Path
19from string import Template
20from typing import Callable, Dict, List, Optional, Union
21
22import pytest
23
25 ExceededStreamError,
26 FixtureResult,
27 ProcessTimeoutError,
28)
29from GaudiTesting.utils import file_path_for_class, kill_tree, which
30
# Maximum number of bytes captured from each output stream (stdout or stderr)
# of a test program; can be overridden with GAUDI_TEST_STDOUT_LIMIT.
STDOUT_LIMIT = int(os.environ.get("GAUDI_TEST_STDOUT_LIMIT", 100 * 1024 * 1024))
32
class SubprocessBaseTest:
    """
    A base class for running and managing subprocess executions within a test framework.
    It provides mechanisms for setting up the environment, preparing commands,
    and handling subprocess output and errors.

    Subclasses configure the run through the class attributes below; the
    ``test_*`` methods receive their arguments from pytest fixtures
    (e.g. ``record_property``).
    """

    # Program to run followed by its arguments.
    command: Optional[List[str]] = None
    # Reference associated with the test, if any.
    # NOTE(review): not used in this class; presumably consumed by fixtures.
    reference: Optional[str] = None
    # Extra "NAME=value" settings applied on top of os.environ (see update_env).
    environment: Optional[List[str]] = None
    # Seconds to wait for the program before it is stopped and reported as timed out.
    timeout: int = 600
    # Expected exit code of the program (checked by test_returncode).
    returncode: int = 0
    # Extra keyword arguments for subprocess.Popen (e.g. "cwd").
    # NOTE: this dict is shared by every subclass that does not define its own.
    popen_kwargs: Dict = {}
    # Regular expressions (applied with re.match) selecting the environment
    # variables recorded with the result.
    # By default record environment variables that are likely to affect job execution
    env_to_record: List[str] = [
        r"^(.*PATH|BINARY_TAG|CMTCONFIG|PWD|TMPDIR|DISPLAY|LC_[A-Z]+|LANGUAGE|.*ROOT|.*OPTS|GAUDI_ENV_TO_RECORD(_\d+)?)$"
    ]

    @property
    def cwd(self) -> Optional[Path]:
        """Working directory configured in ``popen_kwargs``, or None if unset."""
        cwd = self.popen_kwargs.get("cwd")
        return Path(cwd) if cwd else None

    @classmethod
    def resolve_path(cls, path: Union[Path, str]) -> str:
        """
        Resolve the given path to an absolute path, expanding environment variables.

        A trailing ``:suffix`` (as in ``path/to/file:some_suffix``) is split off
        before resolution and re-attached afterwards.  A relative path is looked
        up in the directory of ``file_path_for_class(cls)``; if nothing exists
        there, the (variable-expanded) path is returned unchanged.  An empty path
        is returned as is.
        """
        path = os.path.expandvars(str(path))

        # handle the special case "path/to/file:some_suffix"
        suffix = ""
        if ":" in path:
            path, suffix = path.rsplit(":", 1)
            suffix = f":{suffix}"

        # An empty path would otherwise resolve to base_dir itself, turning an
        # empty command-line argument into a directory name.
        if path and not os.path.isabs(path):
            base_dir = file_path_for_class(cls).parent
            possible_path = str((base_dir / path).resolve())
            if os.path.exists(possible_path):
                path = possible_path
        return path + suffix

    @classmethod
    def update_env(cls, env: Dict[str, str]) -> None:
        """
        Apply the ``NAME=value`` entries of ``cls.environment`` to ``env`` in place.

        ``$VAR`` references in each value are expanded against ``env`` as it is at
        that point, so later entries can use earlier ones.  Unknown variables are
        left as they are.
        """
        for item in cls.environment or ():
            key, value = item.split("=", 1)
            env[key] = cls.expand_vars_from(value, env)

    @classmethod
    def _prepare_environment(cls) -> Dict[str, str]:
        """Return a copy of ``os.environ`` updated with ``cls.environment``."""
        env = dict(os.environ)
        cls.update_env(env)
        return env

    @staticmethod
    def expand_vars_from(value: str, env: Dict[str, str]) -> str:
        """Expand ``$VAR``/``${VAR}`` in ``value`` from ``env``, leaving unknown ones."""
        return Template(value).safe_substitute(env)

    @staticmethod
    def unset_vars(env: Dict[str, str], vars_to_unset: List[str]) -> None:
        """Remove the given variables from ``env`` (missing ones are ignored)."""
        for var in vars_to_unset:
            env.pop(var, None)

    @classmethod
    def _determine_program(cls, prog: str) -> str:
        """
        Locate the executable ``prog``: on the PATH via ``which``, else via ``resolve_path``.

        Names without a .exe/.py/.bat extension get ".exe" appended.
        NOTE(review): assumes ``which`` also tries the name without ".exe";
        confirm in GaudiTesting.utils.
        """
        if not prog.lower().endswith((".exe", ".py", ".bat")):
            prog += ".exe"
        return which(prog) or cls.resolve_path(prog)

    @classmethod
    def _prepare_command(cls, tmp_path=Path()) -> List[str]:
        """
        Prepare the command to be executed, resolving paths for each part.

        Arguments starting with "-" are treated as options and kept verbatim.
        ``tmp_path`` is currently unused.
        """
        command = [cls._determine_program(cls.command[0])]
        command.extend(
            part if part.startswith("-") else cls.resolve_path(part)
            for part in cls.command[1:]
        )
        return command

    @classmethod
    def _handle_timeout(cls, proc: subprocess.Popen) -> str:
        """
        Handle a process timeout: collect a stack trace, then terminate the process.

        Returns the stack trace text.
        """
        stack_trace = cls._collect_stack_trace(proc)
        cls._terminate_process(proc)
        return stack_trace

    @staticmethod
    def _collect_stack_trace(proc: subprocess.Popen) -> str:
        """
        Return the backtraces of all threads of ``proc`` as reported by gdb.

        If gdb cannot be started, a short explanatory message is returned instead,
        so that the caller can still terminate the process.
        """
        cmd = [
            "gdb",
            "--pid",
            str(proc.pid),
            "--batch",
            "--eval-command=thread apply all backtrace",
        ]
        try:
            gdb = subprocess.Popen(
                cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
            )
        except OSError as err:  # e.g. gdb not installed
            return f"could not run gdb to collect the stack trace: {err}"
        return gdb.communicate()[0].decode("utf-8", errors="backslashreplace")

    @staticmethod
    def _terminate_process(proc: subprocess.Popen) -> None:
        """
        Stop ``proc`` and its children with ``kill_tree``.

        SIGTERM is sent first.  SIGKILL follows if the process has not exited
        within 60 seconds.
        """
        kill_tree(proc.pid, signal.SIGTERM)
        try:
            proc.wait(60)
        except subprocess.TimeoutExpired:
            # Popen.wait raises (rather than returning) when the timeout expires,
            # so the escalation has to happen here.
            kill_tree(proc.pid, signal.SIGKILL)
            proc.wait()

    @classmethod
    def _prepare_execution(cls, tmp_path=None):
        """
        Build the command and environment for a run.

        Also creates the working directory (``popen_kwargs["cwd"]``, falling
        back to ``tmp_path``) when one is known.  Returns ``(command, env)``.
        """
        command = cls._prepare_command(tmp_path=tmp_path)
        env = cls._prepare_environment()
        workdir = cls.popen_kwargs.get("cwd", tmp_path)
        if workdir:  # nothing to create when neither is given
            os.makedirs(workdir, exist_ok=True)
        return command, env

    @classmethod
    def _capture_output(cls, proc: subprocess.Popen) -> tuple:
        """
        Collect the stdout and stderr of ``proc`` while it runs.

        Reading stops when the process exits (after draining what it left in
        the pipes) or as soon as one stream grows beyond ``STDOUT_LIMIT`` bytes.

        Returns ``(stdout, stderr, exceeded_stream)``: the captured bytes of
        both streams, and "stdout"/"stderr" for the stream over the limit, or None.
        """
        chunk_size = 64 * 1024
        # how long select() may block before checking again whether proc exited
        poll_interval = 1.0
        streams = {
            proc.stdout.fileno(): "stdout",
            proc.stderr.fileno(): "stderr",
        }
        chunks = {name: [] for name in streams.values()}
        sizes = dict.fromkeys(streams.values(), 0)
        open_fds = list(streams)
        exceeded_stream = None

        def read_ready(wait: float) -> bool:
            """Read one chunk from every ready stream; return whether any was ready."""
            nonlocal exceeded_stream
            readable, _, _ = select.select(open_fds, [], [], wait)
            for fileno in readable:
                data = os.read(fileno, chunk_size)
                if not data:
                    # EOF: stop selecting on this stream, it would stay readable forever
                    open_fds.remove(fileno)
                    continue
                name = streams[fileno]
                chunks[name].append(data)
                # running total instead of re-summing every chunk on each read
                sizes[name] += len(data)
                if sizes[name] > STDOUT_LIMIT:
                    exceeded_stream = name
                    break
            return bool(readable)

        while exceeded_stream is None and proc.poll() is None:
            if open_fds:
                read_ready(poll_interval)
            else:
                # both pipes are closed but the process is still running
                try:
                    proc.wait(poll_interval)
                except subprocess.TimeoutExpired:
                    pass

        # The process has exited: collect the output it produced that is still
        # buffered in the pipes, without waiting for more.
        while exceeded_stream is None and open_fds and read_ready(0):
            pass

        return b"".join(chunks["stdout"]), b"".join(chunks["stderr"]), exceeded_stream

    @classmethod
    def _select_env_to_record(cls, env: Dict[str, str]) -> Dict[str, str]:
        """
        Return the variables of ``env`` worth recording with the result.

        A variable is kept if its name matches (with ``re.match``) at least one
        expression.  The expressions are those in ``cls.env_to_record`` plus the
        values of the variables GAUDI_ENV_TO_RECORD, GAUDI_ENV_TO_RECORD_1,
        GAUDI_ENV_TO_RECORD_2, ...  If there are no expressions, nothing is kept.
        """
        patterns = list(cls.env_to_record) if cls.env_to_record else []
        # fullmatch: only GAUDI_ENV_TO_RECORD and GAUDI_ENV_TO_RECORD_<n> provide
        # expressions, not every name that merely starts the same way
        patterns.extend(
            value
            for name, value in env.items()
            if re.fullmatch(r"GAUDI_ENV_TO_RECORD(_\d+)?", name)
        )
        compiled = [re.compile(pattern) for pattern in patterns]
        return {
            key: value
            for key, value in env.items()
            if any(pattern.match(key) for pattern in compiled)
        }

    @classmethod
    def run_program(cls, tmp_path=None) -> FixtureResult:
        """
        Run the specified program and capture its output.

        If the program runs longer than ``cls.timeout`` seconds, a gdb stack
        trace is collected and the program is killed.  If one of its output
        streams exceeds ``STDOUT_LIMIT`` bytes, the program is terminated.  In
        both cases the corresponding exception is stored as ``run_exception``
        in the returned result rather than raised.
        """
        start_time = datetime.now()
        command, env = cls._prepare_execution(tmp_path=tmp_path)

        proc = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
            **cls.popen_kwargs,
        )

        captured = []  # receives (stdout, stderr, exceeded_stream) from the reader
        thread = threading.Thread(
            target=lambda: captured.append(cls._capture_output(proc)), daemon=True
        )
        thread.start()
        thread.join(cls.timeout)

        run_exception = None
        timed_out = thread.is_alive()
        if timed_out:
            stack_trace = cls._handle_timeout(proc)
            # proc is gone now, so the reader finishes promptly; wait for it so
            # that the output collected so far is complete before it is used
            thread.join()
            run_exception = ProcessTimeoutError("Process timed out", stack_trace)

        stdout, stderr, exceeded_stream = captured[0] if captured else (b"", b"", None)
        if not timed_out and exceeded_stream:
            # nobody reads the pipes any more: the program could block forever
            # on a full pipe, so stop it
            if proc.poll() is None:
                cls._terminate_process(proc)
            run_exception = ExceededStreamError(
                "Stream exceeded size limit", exceeded_stream
            )

        proc.stdout.close()
        proc.stderr.close()

        end_time = datetime.now()

        completed_process = subprocess.CompletedProcess(
            args=command,
            returncode=proc.returncode,
            stdout=stdout,
            stderr=stderr,
        )
        return FixtureResult(
            completed_process=completed_process,
            start_time=start_time,
            end_time=end_time,
            run_exception=run_exception,
            command=cls.command,
            expanded_command=command,
            env=cls._select_env_to_record(env),
            cwd=cls.popen_kwargs.get("cwd"),
        )

    @classmethod
    def run_program_for_dbg(cls, tmp_path):
        """
        Run the program in the foreground, without capturing its output
        (unless ``popen_kwargs`` redirects it), for interactive debugging.
        """
        tmp_path = cls.popen_kwargs.get("cwd", tmp_path)
        command, env = cls._prepare_execution(tmp_path=tmp_path)
        print("Running the command: ", command)
        # popen_kwargs usually contains "cwd" already; passing it both there and
        # as an explicit keyword raises TypeError, so merge into a single dict
        subprocess.run(
            command,
            env=env,
            **{**cls.popen_kwargs, "cwd": tmp_path},
        )

    @pytest.mark.do_not_collect_source
    def test_fixture_setup(
        self,
        record_property: Callable[[str, str], None],
        fixture_result: FixtureResult,
        reference_path: Optional[Path],
    ) -> None:
        """
        Record properties and handle any failures during fixture setup.

        Every non-None item of ``fixture_result.to_dict()`` is recorded, plus
        the reference file path if there is one.  The test fails if the run
        produced a ``run_exception`` (timeout or oversized output).
        """
        for key, value in fixture_result.to_dict().items():
            if value is not None:
                record_property(key, value)
        if reference_path:
            record_property("reference_file", str(reference_path))

        if fixture_result.run_exception:
            pytest.fail(f"{fixture_result.run_exception}")

    @pytest.mark.do_not_collect_source
    def test_returncode(self, returncode: int) -> None:
        """
        Test that the return code matches the expected value.
        """
        assert returncode == self.returncode
None unset_vars(Dict[str, str] env, List[str] vars_to_unset)
None test_fixture_setup(self, Callable[[str, str], None] record_property, FixtureResult fixture_result, Optional[Path] reference_path)
str expand_vars_from(str value, Dict[str, str] env)