summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--docs/manifest-format.md41
-rw-r--r--docs/windows.md7
-rw-r--r--git_command.py6
-rw-r--r--git_config.py4
-rwxr-xr-xmain.py126
-rw-r--r--man/repo-manifest.150
-rw-r--r--manifest_xml.py11
-rw-r--r--progress.py2
-rw-r--r--project.py196
-rwxr-xr-xrelease/check-metadata.py2
-rw-r--r--run_tests.vpython316
-rw-r--r--ssh.py2
-rw-r--r--subcmds/gc.py17
-rw-r--r--subcmds/grep.py2
-rw-r--r--subcmds/help.py8
-rw-r--r--subcmds/sync.py64
-rw-r--r--subcmds/upload.py10
-rw-r--r--tests/test_color.py10
-rw-r--r--tests/test_git_config.py11
-rw-r--r--tests/test_git_trace2_event_log.py755
-rw-r--r--tests/test_main.py166
-rw-r--r--tests/test_manifest_xml.py1059
-rw-r--r--tests/test_project.py229
-rw-r--r--tests/test_subcmds_forall.py184
-rw-r--r--tests/test_subcmds_sync.py109
-rw-r--r--tests/test_subcmds_upload.py81
-rw-r--r--tests/test_wrapper.py657
-rw-r--r--tests/utils_for_test.py5
28 files changed, 2371 insertions, 1459 deletions
diff --git a/docs/manifest-format.md b/docs/manifest-format.md
index f0149dd80..42fb1bfe2 100644
--- a/docs/manifest-format.md
+++ b/docs/manifest-format.md
@@ -73,18 +73,19 @@ following DTD:
73 project*, 73 project*,
74 copyfile*, 74 copyfile*,
75 linkfile*)> 75 linkfile*)>
76 <!ATTLIST project name CDATA #REQUIRED> 76 <!ATTLIST project name CDATA #REQUIRED>
77 <!ATTLIST project path CDATA #IMPLIED> 77 <!ATTLIST project path CDATA #IMPLIED>
78 <!ATTLIST project remote IDREF #IMPLIED> 78 <!ATTLIST project remote IDREF #IMPLIED>
79 <!ATTLIST project revision CDATA #IMPLIED> 79 <!ATTLIST project revision CDATA #IMPLIED>
80 <!ATTLIST project dest-branch CDATA #IMPLIED> 80 <!ATTLIST project dest-branch CDATA #IMPLIED>
81 <!ATTLIST project groups CDATA #IMPLIED> 81 <!ATTLIST project groups CDATA #IMPLIED>
82 <!ATTLIST project sync-c CDATA #IMPLIED> 82 <!ATTLIST project sync-c CDATA #IMPLIED>
83 <!ATTLIST project sync-s CDATA #IMPLIED> 83 <!ATTLIST project sync-s CDATA #IMPLIED>
84 <!ATTLIST project sync-tags CDATA #IMPLIED> 84 <!ATTLIST project sync-tags CDATA #IMPLIED>
85 <!ATTLIST project upstream CDATA #IMPLIED> 85 <!ATTLIST project upstream CDATA #IMPLIED>
86 <!ATTLIST project clone-depth CDATA #IMPLIED> 86 <!ATTLIST project clone-depth CDATA #IMPLIED>
87 <!ATTLIST project force-path CDATA #IMPLIED> 87 <!ATTLIST project force-path CDATA #IMPLIED>
88 <!ATTLIST project sync-strategy CDATA #IMPLIED>
88 89
89 <!ELEMENT annotation EMPTY> 90 <!ELEMENT annotation EMPTY>
90 <!ATTLIST annotation name CDATA #REQUIRED> 91 <!ATTLIST annotation name CDATA #REQUIRED>
@@ -389,6 +390,22 @@ rather than the `name` attribute. This attribute only applies to the
389local mirrors syncing, it will be ignored when syncing the projects in a 390local mirrors syncing, it will be ignored when syncing the projects in a
390client working directory. 391client working directory.
391 392
393Attribute `sync-strategy`: Set the sync strategy used when fetching this
394project. Currently the only supported value is `stateless`. When set to
395`stateless`, repo will run a reflog expiration and aggressive garbage collection
396at the end of the sync process. This is useful for projects that contain
397large binary files and use `clone-depth="1"`, where garbage can accumulate
398as binaries are added, deleted, or modified across successive syncs.
399
400During a stateless sync, repo checks the following before cleaning up:
4011. The project does not share an object directory with other projects.
4022. The working tree is clean (no uncommitted changes, no untracked files).
4033. There are no unpushed local commits.
4044. There is no Git stash.
405
406If any of these conditions are not met, repo falls back to a standard
407sync without garbage collection.
408
392### Element extend-project 409### Element extend-project
393 410
394Modify the attributes of the named project. 411Modify the attributes of the named project.
diff --git a/docs/windows.md b/docs/windows.md
index 4282bebfd..575944503 100644
--- a/docs/windows.md
+++ b/docs/windows.md
@@ -50,8 +50,11 @@ Git worktrees (see the previous section for more info).
50Repo will use symlinks heavily internally. 50Repo will use symlinks heavily internally.
51On *NIX platforms, this isn't an issue, but Windows makes it a bit difficult. 51On *NIX platforms, this isn't an issue, but Windows makes it a bit difficult.
52 52
53There are some documents out there for how to do this, but usually the easiest 53The easiest method to allow users to create symlinks is by enabling
54answer is to run your shell as an Administrator and invoke repo/git in that. 54[Windows Developer Mode](https://learn.microsoft.com/en-us/windows/advanced-settings/developer-mode).
55
56The next easiest answer is to run your shell as an Administrator and invoke
57repo/git in that.
55 58
56This isn't a great solution, but Windows doesn't make this easy, so here we are. 59This isn't a great solution, but Windows doesn't make this easy, so here we are.
57 60
diff --git a/git_command.py b/git_command.py
index 89e8d2f44..ebba92608 100644
--- a/git_command.py
+++ b/git_command.py
@@ -47,7 +47,7 @@ logger = RepoLogger(__file__)
47 47
48 48
49class _GitCall: 49class _GitCall:
50 @functools.lru_cache(maxsize=None) 50 @functools.lru_cache(maxsize=None) # noqa: B019
51 def version_tuple(self): 51 def version_tuple(self):
52 ret = Wrapper().ParseGitVersion() 52 ret = Wrapper().ParseGitVersion()
53 if ret is None: 53 if ret is None:
@@ -95,7 +95,7 @@ def RepoSourceVersion():
95 ver = ver[1:] 95 ver = ver[1:]
96 else: 96 else:
97 ver = "unknown" 97 ver = "unknown"
98 setattr(RepoSourceVersion, "version", ver) 98 RepoSourceVersion.version = ver
99 99
100 return ver 100 return ver
101 101
@@ -611,7 +611,7 @@ class GitCommandError(GitError):
611 self.git_stderr = git_stderr 611 self.git_stderr = git_stderr
612 612
613 @property 613 @property
614 @functools.lru_cache(maxsize=None) 614 @functools.lru_cache(maxsize=None) # noqa: B019
615 def suggestion(self): 615 def suggestion(self):
616 """Returns helpful next steps for the given stderr.""" 616 """Returns helpful next steps for the given stderr."""
617 if not self.git_stderr: 617 if not self.git_stderr:
diff --git a/git_config.py b/git_config.py
index 5559657ae..e80a0322d 100644
--- a/git_config.py
+++ b/git_config.py
@@ -42,7 +42,7 @@ SYNC_STATE_PREFIX = "repo.syncstate."
42 42
43ID_RE = re.compile(r"^[0-9a-f]{40}$") 43ID_RE = re.compile(r"^[0-9a-f]{40}$")
44 44
45REVIEW_CACHE = dict() 45REVIEW_CACHE = {}
46 46
47 47
48def IsChange(rev): 48def IsChange(rev):
@@ -111,7 +111,7 @@ class GitConfig:
111 return cls(configfile=os.path.join(gitdir, "config"), defaults=defaults) 111 return cls(configfile=os.path.join(gitdir, "config"), defaults=defaults)
112 112
113 def __init__(self, configfile, defaults=None, jsonFile=None): 113 def __init__(self, configfile, defaults=None, jsonFile=None):
114 self.file = configfile 114 self.file = str(configfile)
115 self.defaults = defaults 115 self.defaults = defaults
116 self._cache_dict = None 116 self._cache_dict = None
117 self._section_dict = None 117 self._section_dict = None
diff --git a/main.py b/main.py
index c55922255..86a01a62a 100755
--- a/main.py
+++ b/main.py
@@ -19,6 +19,7 @@ People shouldn't run this directly; instead, they should use the `repo` wrapper
19which takes care of execing this entry point. 19which takes care of execing this entry point.
20""" 20"""
21 21
22import difflib
22import getpass 23import getpass
23import json 24import json
24import netrc 25import netrc
@@ -29,6 +30,7 @@ import signal
29import sys 30import sys
30import textwrap 31import textwrap
31import time 32import time
33from typing import Optional
32import urllib.request 34import urllib.request
33 35
34from repo_logging import RepoLogger 36from repo_logging import RepoLogger
@@ -292,6 +294,102 @@ class _Repo:
292 result = run() 294 result = run()
293 return result 295 return result
294 296
297 def _autocorrect_command_name(
298 self, name: str, config: RepoConfig
299 ) -> Optional[str]:
300 """Autocorrect command name based on user's git config."""
301
302 close_commands = difflib.get_close_matches(
303 name, self.commands.keys(), n=5, cutoff=0.7
304 )
305
306 if not close_commands:
307 logger.error(
308 "repo: '%s' is not a repo command. See 'repo help'.", name
309 )
310 return None
311
312 assumed = close_commands[0]
313 autocorrect = config.GetString("help.autocorrect")
314
315 # If there are multiple close matches, git won't automatically run one.
316 # We'll always prompt instead of guessing.
317 if len(close_commands) > 1:
318 autocorrect = "prompt"
319
320 # Handle git configuration boolean values:
321 # 0, "false", "off", "no", "show": show suggestion (default)
322 # 1, "true", "on", "yes", "immediate": run suggestion immediately
323 # "never": don't run or show any suggested command
324 # "prompt": show the suggestion and prompt for confirmation
325 # positive number > 1: run suggestion after specified deciseconds
326 if autocorrect is None:
327 autocorrect = "0"
328
329 autocorrect = autocorrect.lower()
330
331 if autocorrect in ("0", "false", "off", "no", "show"):
332 autocorrect = 0
333 elif autocorrect in ("true", "on", "yes", "immediate"):
334 autocorrect = -1 # immediate
335 elif autocorrect == "never":
336 return None
337 elif autocorrect == "prompt":
338 logger.warning(
339 "You called a repo command named "
340 "'%s', which does not exist.",
341 name,
342 )
343 try:
344 resp = input(f"Run '{assumed}' instead [y/N]? ")
345 if resp.lower().startswith("y"):
346 return assumed
347 except (KeyboardInterrupt, EOFError):
348 pass
349 return None
350 else:
351 try:
352 autocorrect = int(autocorrect)
353 except ValueError:
354 autocorrect = 0
355
356 if autocorrect != 0:
357 if autocorrect < 0:
358 logger.warning(
359 "You called a repo command named "
360 "'%s', which does not exist.\n"
361 "Continuing assuming that "
362 "you meant '%s'.",
363 name,
364 assumed,
365 )
366 else:
367 delay = autocorrect * 0.1
368 logger.warning(
369 "You called a repo command named "
370 "'%s', which does not exist.\n"
371 "Continuing in %.1f seconds, assuming "
372 "that you meant '%s'.",
373 name,
374 delay,
375 assumed,
376 )
377 try:
378 time.sleep(delay)
379 except KeyboardInterrupt:
380 return None
381 return assumed
382
383 logger.error(
384 "repo: '%s' is not a repo command. See 'repo help'.", name
385 )
386 logger.warning(
387 "The most similar command%s\n\t%s",
388 "s are" if len(close_commands) > 1 else " is",
389 "\n\t".join(close_commands),
390 )
391 return None
392
295 def _RunLong(self, name, gopts, argv, git_trace2_event_log): 393 def _RunLong(self, name, gopts, argv, git_trace2_event_log):
296 """Execute the (longer running) requested subcommand.""" 394 """Execute the (longer running) requested subcommand."""
297 result = 0 395 result = 0
@@ -306,20 +404,22 @@ class _Repo:
306 outer_client=outer_client, 404 outer_client=outer_client,
307 ) 405 )
308 406
309 try: 407 if name not in self.commands:
310 cmd = self.commands[name]( 408 corrected_name = self._autocorrect_command_name(
311 repodir=self.repodir, 409 name, outer_client.globalConfig
312 client=repo_client,
313 manifest=repo_client.manifest,
314 outer_client=outer_client,
315 outer_manifest=outer_client.manifest,
316 git_event_log=git_trace2_event_log,
317 )
318 except KeyError:
319 logger.error(
320 "repo: '%s' is not a repo command. See 'repo help'.", name
321 ) 410 )
322 return 1 411 if not corrected_name:
412 return 1
413 name = corrected_name
414
415 cmd = self.commands[name](
416 repodir=self.repodir,
417 client=repo_client,
418 manifest=repo_client.manifest,
419 outer_client=outer_client,
420 outer_manifest=outer_client.manifest,
421 git_event_log=git_trace2_event_log,
422 )
323 423
324 Editor.globalConfig = cmd.client.globalConfig 424 Editor.globalConfig = cmd.client.globalConfig
325 425
diff --git a/man/repo-manifest.1 b/man/repo-manifest.1
index 4d74fde89..75c9fa9e1 100644
--- a/man/repo-manifest.1
+++ b/man/repo-manifest.1
@@ -1,5 +1,5 @@
1.\" DO NOT MODIFY THIS FILE! It was generated by help2man. 1.\" DO NOT MODIFY THIS FILE! It was generated by help2man.
2.TH REPO "1" "March 2026" "repo manifest" "Repo Manual" 2.TH REPO "1" "April 2026" "repo manifest" "Repo Manual"
3.SH NAME 3.SH NAME
4repo \- repo manifest - manual page for repo manifest 4repo \- repo manifest - manual page for repo manifest
5.SH SYNOPSIS 5.SH SYNOPSIS
@@ -165,15 +165,32 @@ IDREF #IMPLIED>
165.TP 165.TP
166<!ATTLIST project revision 166<!ATTLIST project revision
167CDATA #IMPLIED> 167CDATA #IMPLIED>
168.TP
169<!ATTLIST project dest\-branch
170CDATA #IMPLIED>
171.TP
172<!ATTLIST project groups
173CDATA #IMPLIED>
174.TP
175<!ATTLIST project sync\-c
176CDATA #IMPLIED>
177.TP
178<!ATTLIST project sync\-s
179CDATA #IMPLIED>
180.TP
181<!ATTLIST project sync\-tags
182CDATA #IMPLIED>
183.TP
184<!ATTLIST project upstream
185CDATA #IMPLIED>
186.TP
187<!ATTLIST project clone\-depth
188CDATA #IMPLIED>
189.TP
190<!ATTLIST project force\-path
191CDATA #IMPLIED>
168.IP 192.IP
169<!ATTLIST project dest\-branch CDATA #IMPLIED> 193<!ATTLIST project sync\-strategy CDATA #IMPLIED>
170<!ATTLIST project groups CDATA #IMPLIED>
171<!ATTLIST project sync\-c CDATA #IMPLIED>
172<!ATTLIST project sync\-s CDATA #IMPLIED>
173<!ATTLIST project sync\-tags CDATA #IMPLIED>
174<!ATTLIST project upstream CDATA #IMPLIED>
175<!ATTLIST project clone\-depth CDATA #IMPLIED>
176<!ATTLIST project force\-path CDATA #IMPLIED>
177.IP 194.IP
178<!ELEMENT annotation EMPTY> 195<!ELEMENT annotation EMPTY>
179<!ATTLIST annotation name CDATA #REQUIRED> 196<!ATTLIST annotation name CDATA #REQUIRED>
@@ -469,6 +486,21 @@ mirror repository according to its `path` attribute (if supplied) rather than
469the `name` attribute. This attribute only applies to the local mirrors syncing, 486the `name` attribute. This attribute only applies to the local mirrors syncing,
470it will be ignored when syncing the projects in a client working directory. 487it will be ignored when syncing the projects in a client working directory.
471.PP 488.PP
489Attribute `sync\-strategy`: Set the sync strategy used when fetching this
490project. Currently the only supported value is `stateless`. When set to
491`stateless`, repo will run a reflog expiration and aggressive garbage collection
492at the end of the sync process. This is useful for projects that contain large
493binary files and use `clone\-depth="1"`, where garbage can accumulate as binaries
494are added, deleted, or modified across successive syncs.
495.PP
496During a stateless sync, repo checks the following before cleaning up: 1. The
497project does not share an object directory with other projects. 2. The working
498tree is clean (no uncommitted changes, no untracked files). 3. There are no
499unpushed local commits. 4. There is no Git stash.
500.PP
501If any of these conditions are not met, repo falls back to a standard sync
502without garbage collection.
503.PP
472Element extend\-project 504Element extend\-project
473.PP 505.PP
474Modify the attributes of the named project. 506Modify the attributes of the named project.
diff --git a/manifest_xml.py b/manifest_xml.py
index 084ca5ab2..5dc9d2fe9 100644
--- a/manifest_xml.py
+++ b/manifest_xml.py
@@ -759,14 +759,17 @@ https://gerrit.googlesource.com/git-repo/+/HEAD/docs/manifest-format.md
759 if p.clone_depth: 759 if p.clone_depth:
760 e.setAttribute("clone-depth", str(p.clone_depth)) 760 e.setAttribute("clone-depth", str(p.clone_depth))
761 761
762 if p.sync_strategy:
763 e.setAttribute("sync-strategy", str(p.sync_strategy))
764
762 self._output_manifest_project_extras(p, e) 765 self._output_manifest_project_extras(p, e)
763 766
764 if p.subprojects: 767 if p.subprojects:
765 subprojects = {subp.name for subp in p.subprojects} 768 subprojects = {subp.name for subp in p.subprojects}
766 output_projects(p, e, list(sorted(subprojects))) 769 output_projects(p, e, sorted(subprojects))
767 770
768 projects = {p.name for p in self._paths.values() if not p.parent} 771 projects = {p.name for p in self._paths.values() if not p.parent}
769 output_projects(None, root, list(sorted(projects))) 772 output_projects(None, root, sorted(projects))
770 773
771 if self._repo_hooks_project: 774 if self._repo_hooks_project:
772 root.appendChild(doc.createTextNode("")) 775 root.appendChild(doc.createTextNode(""))
@@ -823,7 +826,6 @@ https://gerrit.googlesource.com/git-repo/+/HEAD/docs/manifest-format.md
823 "submanifest", 826 "submanifest",
824 # These are children of 'project' nodes. 827 # These are children of 'project' nodes.
825 "annotation", 828 "annotation",
826 "project",
827 "copyfile", 829 "copyfile",
828 "linkfile", 830 "linkfile",
829 } 831 }
@@ -1939,6 +1941,8 @@ https://gerrit.googlesource.com/git-repo/+/HEAD/docs/manifest-format.md
1939 % (self.manifestFile, clone_depth) 1941 % (self.manifestFile, clone_depth)
1940 ) 1942 )
1941 1943
1944 sync_strategy = node.getAttribute("sync-strategy") or None
1945
1942 dest_branch = ( 1946 dest_branch = (
1943 node.getAttribute("dest-branch") or self._default.destBranchExpr 1947 node.getAttribute("dest-branch") or self._default.destBranchExpr
1944 ) 1948 )
@@ -1985,6 +1989,7 @@ https://gerrit.googlesource.com/git-repo/+/HEAD/docs/manifest-format.md
1985 sync_s=sync_s, 1989 sync_s=sync_s,
1986 sync_tags=sync_tags, 1990 sync_tags=sync_tags,
1987 clone_depth=clone_depth, 1991 clone_depth=clone_depth,
1992 sync_strategy=sync_strategy,
1988 upstream=upstream, 1993 upstream=upstream,
1989 parent=parent, 1994 parent=parent,
1990 dest_branch=dest_branch, 1995 dest_branch=dest_branch,
diff --git a/progress.py b/progress.py
index 9a91dcd65..2c0fead5a 100644
--- a/progress.py
+++ b/progress.py
@@ -159,6 +159,8 @@ class Progress:
159 inc: The number of items completed. 159 inc: The number of items completed.
160 msg: The message to display. If None, use the last message. 160 msg: The message to display. If None, use the last message.
161 """ 161 """
162 if self._ended:
163 return
162 self._done += inc 164 self._done += inc
163 if msg is None: 165 if msg is None:
164 msg = self._last_msg 166 msg = self._last_msg
diff --git a/project.py b/project.py
index caeaa5211..031440413 100644
--- a/project.py
+++ b/project.py
@@ -28,7 +28,7 @@ import sys
28import tarfile 28import tarfile
29import tempfile 29import tempfile
30import time 30import time
31from typing import List, NamedTuple 31from typing import List, NamedTuple, Optional
32import urllib.parse 32import urllib.parse
33 33
34from color import Coloring 34from color import Coloring
@@ -225,7 +225,7 @@ class ReviewableBranch:
225 225
226 @property 226 @property
227 def unabbrev_commits(self): 227 def unabbrev_commits(self):
228 r = dict() 228 r = {}
229 for commit in self.project.bare_git.rev_list( 229 for commit in self.project.bare_git.rev_list(
230 not_rev(self.base), R_HEADS + self.name, "--" 230 not_rev(self.base), R_HEADS + self.name, "--"
231 ): 231 ):
@@ -553,11 +553,12 @@ class Project:
553 revisionExpr, 553 revisionExpr,
554 revisionId, 554 revisionId,
555 rebase=True, 555 rebase=True,
556 groups=set(), 556 groups=None,
557 sync_c=False, 557 sync_c=False,
558 sync_s=False, 558 sync_s=False,
559 sync_tags=True, 559 sync_tags=True,
560 clone_depth=None, 560 clone_depth=None,
561 sync_strategy=None,
561 upstream=None, 562 upstream=None,
562 parent=None, 563 parent=None,
563 use_git_worktrees=False, 564 use_git_worktrees=False,
@@ -605,11 +606,12 @@ class Project:
605 self.SetRevision(revisionExpr, revisionId=revisionId) 606 self.SetRevision(revisionExpr, revisionId=revisionId)
606 607
607 self.rebase = rebase 608 self.rebase = rebase
608 self.groups = groups 609 self.groups = groups if groups is not None else set()
609 self.sync_c = sync_c 610 self.sync_c = sync_c
610 self.sync_s = sync_s 611 self.sync_s = sync_s
611 self.sync_tags = sync_tags 612 self.sync_tags = sync_tags
612 self.clone_depth = clone_depth 613 self.clone_depth = clone_depth
614 self.sync_strategy = sync_strategy
613 self.upstream = upstream 615 self.upstream = upstream
614 self.parent = parent 616 self.parent = parent
615 # NB: Do not use this setting in __init__ to change behavior so that the 617 # NB: Do not use this setting in __init__ to change behavior so that the
@@ -627,6 +629,7 @@ class Project:
627 self.linkfiles = {} 629 self.linkfiles = {}
628 self.annotations = [] 630 self.annotations = []
629 self.dest_branch = dest_branch 631 self.dest_branch = dest_branch
632 self.stateless_prune_needed = False
630 633
631 # This will be filled in if a project is later identified to be the 634 # This will be filled in if a project is later identified to be the
632 # project containing repo hooks. 635 # project containing repo hooks.
@@ -756,6 +759,18 @@ class Project:
756 return True 759 return True
757 return False 760 return False
758 761
762 def HasStash(self) -> bool:
763 """Returns True if there is a stash in the repository."""
764 p = GitCommand(
765 self,
766 ["rev-parse", "--verify", "refs/stash"],
767 bare=True,
768 capture_stdout=True,
769 capture_stderr=True,
770 log_as_error=False,
771 )
772 return p.Wait() == 0
773
759 _userident_name = None 774 _userident_name = None
760 _userident_email = None 775 _userident_email = None
761 776
@@ -943,7 +958,7 @@ class Project:
943 out.important("prior sync failed; rebase still in progress") 958 out.important("prior sync failed; rebase still in progress")
944 out.nl() 959 out.nl()
945 960
946 paths = list() 961 paths = []
947 paths.extend(di.keys()) 962 paths.extend(di.keys())
948 paths.extend(df.keys()) 963 paths.extend(df.keys())
949 paths.extend(do) 964 paths.extend(do)
@@ -1239,12 +1254,74 @@ class Project:
1239 logger.error("error: Cannot extract archive %s: %s", tarpath, e) 1254 logger.error("error: Cannot extract archive %s: %s", tarpath, e)
1240 return False 1255 return False
1241 1256
1257 def _ShouldStatelessPrune(
1258 self, use_superproject: Optional[bool] = None
1259 ) -> bool:
1260 """Determines if a stateless prune should be performed.
1261
1262 Stateless pruning reclaims space by running a reflog expiration and
1263 garbage collection instead of an incremental fetch. It is only performed
1264 if the repository is clean and has no local-only state.
1265 """
1266 if not self.Exists:
1267 return False
1268
1269 if self._CheckForImmutableRevision(use_superproject=use_superproject):
1270 return False
1271
1272 # Query the target hash from remote to see if we are up-to-date.
1273 target_hash = None
1274 if IsId(self.revisionExpr):
1275 target_hash = self.revisionExpr
1276 else:
1277 output = self._LsRemote(self.upstream or self.revisionExpr)
1278 if output:
1279 target_hash = output.splitlines()[0].split()[0]
1280
1281 if not target_hash:
1282 return False
1283
1284 try:
1285 local_head = self.bare_git.rev_parse("HEAD")
1286 except GitError:
1287 local_head = None
1288
1289 if target_hash == local_head:
1290 return False
1291
1292 # Skip if sharing objects with other projects.
1293 shares_objdir = self.UseAlternates or self.use_git_worktrees
1294 if not shares_objdir:
1295 for p in self.manifest.GetProjectsWithName(self.name):
1296 if p != self and p.objdir == self.objdir:
1297 shares_objdir = True
1298 break
1299
1300 if shares_objdir:
1301 return False
1302
1303 # Skip if HEAD contains any unpushed local commits.
1304 try:
1305 local_commits = self.bare_git.rev_list(
1306 "--count", "HEAD", "--not", "--remotes", "--tags"
1307 )
1308 if int(local_commits[0]) > 0:
1309 return False
1310 except (GitError, IndexError, ValueError):
1311 return False
1312
1313 if self.IsDirty(consider_untracked=True) or self.HasStash():
1314 return False
1315
1316 return True
1317
1242 def Sync_NetworkHalf( 1318 def Sync_NetworkHalf(
1243 self, 1319 self,
1244 quiet=False, 1320 quiet=False,
1245 verbose=False, 1321 verbose=False,
1246 output_redir=None, 1322 output_redir=None,
1247 is_new=None, 1323 is_new=None,
1324 use_superproject=None,
1248 current_branch_only=None, 1325 current_branch_only=None,
1249 force_sync=False, 1326 force_sync=False,
1250 clone_bundle=True, 1327 clone_bundle=True,
@@ -1256,7 +1333,7 @@ class Project:
1256 submodules=False, 1333 submodules=False,
1257 ssh_proxy=None, 1334 ssh_proxy=None,
1258 clone_filter=None, 1335 clone_filter=None,
1259 partial_clone_exclude=set(), 1336 partial_clone_exclude=None,
1260 clone_filter_for_depth=None, 1337 clone_filter_for_depth=None,
1261 ): 1338 ):
1262 """Perform only the network IO portion of the sync process. 1339 """Perform only the network IO portion of the sync process.
@@ -1309,10 +1386,17 @@ class Project:
1309 if clone_bundle and os.path.exists(self.objdir): 1386 if clone_bundle and os.path.exists(self.objdir):
1310 clone_bundle = False 1387 clone_bundle = False
1311 1388
1389 if partial_clone_exclude is None:
1390 partial_clone_exclude = set()
1312 if self.name in partial_clone_exclude: 1391 if self.name in partial_clone_exclude:
1313 clone_bundle = True 1392 clone_bundle = True
1314 clone_filter = None 1393 clone_filter = None
1315 1394
1395 if self.sync_strategy == "stateless" and self._ShouldStatelessPrune(
1396 use_superproject
1397 ):
1398 self.stateless_prune_needed = True
1399
1316 if is_new is None: 1400 if is_new is None:
1317 is_new = not self.Exists 1401 is_new = not self.Exists
1318 if is_new: 1402 if is_new:
@@ -1390,6 +1474,15 @@ class Project:
1390 else: 1474 else:
1391 depth = self.manifest.manifestProject.depth 1475 depth = self.manifest.manifestProject.depth
1392 1476
1477 # If the project has been manually unshallowed (e.g. via
1478 # `git fetch --unshallow`), don't re-shallow it during sync.
1479 if (
1480 depth
1481 and not is_new
1482 and not os.path.exists(os.path.join(self.gitdir, "shallow"))
1483 ):
1484 depth = None
1485
1393 if depth and clone_filter_for_depth: 1486 if depth and clone_filter_for_depth:
1394 depth = None 1487 depth = None
1395 clone_filter = clone_filter_for_depth 1488 clone_filter = clone_filter_for_depth
@@ -1399,7 +1492,13 @@ class Project:
1399 if not ( 1492 if not (
1400 optimized_fetch 1493 optimized_fetch
1401 and IsId(self.revisionExpr) 1494 and IsId(self.revisionExpr)
1402 and self._CheckForImmutableRevision() 1495 and self._CheckForImmutableRevision(
1496 use_superproject=use_superproject
1497 )
1498 and (
1499 not depth
1500 or os.path.exists(os.path.join(self.gitdir, "shallow"))
1501 )
1403 ): 1502 ):
1404 remote_fetched = True 1503 remote_fetched = True
1405 try: 1504 try:
@@ -1409,6 +1508,7 @@ class Project:
1409 verbose=verbose, 1508 verbose=verbose,
1410 output_redir=output_redir, 1509 output_redir=output_redir,
1411 alt_dir=alt_dir, 1510 alt_dir=alt_dir,
1511 use_superproject=use_superproject,
1412 current_branch_only=current_branch_only, 1512 current_branch_only=current_branch_only,
1413 tags=tags, 1513 tags=tags,
1414 prune=prune, 1514 prune=prune,
@@ -1585,6 +1685,23 @@ class Project:
1585 def _dosubmodules(): 1685 def _dosubmodules():
1586 self._SyncSubmodules(quiet=True) 1686 self._SyncSubmodules(quiet=True)
1587 1687
1688 def _doprune() -> None:
1689 """Expire reflogs and run prune-now GC for stateless sync."""
1690 GitCommand(
1691 self,
1692 ["reflog", "expire", "--expire=all", "--all"],
1693 bare=True,
1694 ).Wait()
1695 p = GitCommand(
1696 self,
1697 ["gc", "--prune=now"],
1698 bare=True,
1699 capture_stdout=True,
1700 capture_stderr=True,
1701 )
1702 if p.Wait() != 0:
1703 logger.warning("warn: %s: stateless gc failed", self.name)
1704
1588 head = self.work_git.GetHead() 1705 head = self.work_git.GetHead()
1589 if head.startswith(R_HEADS): 1706 if head.startswith(R_HEADS):
1590 branch = head[len(R_HEADS) :] 1707 branch = head[len(R_HEADS) :]
@@ -1630,6 +1747,8 @@ class Project:
1630 fail(e) 1747 fail(e)
1631 return 1748 return
1632 self._CopyAndLinkFiles() 1749 self._CopyAndLinkFiles()
1750 if self.stateless_prune_needed:
1751 syncbuf.later2(self, _doprune, not verbose)
1633 return 1752 return
1634 1753
1635 if head == revid: 1754 if head == revid:
@@ -1776,6 +1895,9 @@ class Project:
1776 if submodules: 1895 if submodules:
1777 syncbuf.later1(self, _dosubmodules, not verbose) 1896 syncbuf.later1(self, _dosubmodules, not verbose)
1778 1897
1898 if self.stateless_prune_needed:
1899 syncbuf.later2(self, _doprune, not verbose)
1900
1779 def AddCopyFile(self, src, dest, topdir): 1901 def AddCopyFile(self, src, dest, topdir):
1780 """Mark |src| for copying to |dest| (relative to |topdir|). 1902 """Mark |src| for copying to |dest| (relative to |topdir|).
1781 1903
@@ -2397,7 +2519,9 @@ class Project:
2397 2519
2398 return None 2520 return None
2399 2521
2400 def _CheckForImmutableRevision(self): 2522 def _CheckForImmutableRevision(
2523 self, use_superproject: Optional[bool] = None
2524 ) -> bool:
2401 try: 2525 try:
2402 # if revision (sha or tag) is not present then following function 2526 # if revision (sha or tag) is not present then following function
2403 # throws an error. 2527 # throws an error.
@@ -2405,7 +2529,9 @@ class Project:
2405 upstream_rev = None 2529 upstream_rev = None
2406 2530
2407 # Only check upstream when using superproject. 2531 # Only check upstream when using superproject.
2408 if self.upstream and self.manifest.manifestProject.use_superproject: 2532 if self.upstream and git_superproject.UseSuperproject(
2533 use_superproject, self.manifest
2534 ):
2409 upstream_rev = self.GetRemote().ToLocal(self.upstream) 2535 upstream_rev = self.GetRemote().ToLocal(self.upstream)
2410 revs.append(upstream_rev) 2536 revs.append(upstream_rev)
2411 2537
@@ -2419,7 +2545,9 @@ class Project:
2419 2545
2420 # Only verify upstream relationship for superproject scenarios 2546 # Only verify upstream relationship for superproject scenarios
2421 # without affecting plain usage. 2547 # without affecting plain usage.
2422 if self.upstream and self.manifest.manifestProject.use_superproject: 2548 if self.upstream and git_superproject.UseSuperproject(
2549 use_superproject, self.manifest
2550 ):
2423 self.bare_git.merge_base( 2551 self.bare_git.merge_base(
2424 "--is-ancestor", 2552 "--is-ancestor",
2425 self.revisionExpr, 2553 self.revisionExpr,
@@ -2450,6 +2578,7 @@ class Project:
2450 def _RemoteFetch( 2578 def _RemoteFetch(
2451 self, 2579 self,
2452 name=None, 2580 name=None,
2581 use_superproject=None,
2453 current_branch_only=False, 2582 current_branch_only=False,
2454 initial=False, 2583 initial=False,
2455 quiet=False, 2584 quiet=False,
@@ -2489,7 +2618,12 @@ class Project:
2489 tag_name = self.upstream[len(R_TAGS) :] 2618 tag_name = self.upstream[len(R_TAGS) :]
2490 2619
2491 if is_sha1 or tag_name is not None: 2620 if is_sha1 or tag_name is not None:
2492 if self._CheckForImmutableRevision(): 2621 if self._CheckForImmutableRevision(
2622 use_superproject=use_superproject
2623 ) and (
2624 not depth
2625 or os.path.exists(os.path.join(self.gitdir, "shallow"))
2626 ):
2493 if verbose: 2627 if verbose:
2494 print( 2628 print(
2495 "Skipped fetching project %s (already have " 2629 "Skipped fetching project %s (already have "
@@ -2546,7 +2680,7 @@ class Project:
2546 if update_ref_cmds: 2680 if update_ref_cmds:
2547 GitCommand( 2681 GitCommand(
2548 self, 2682 self,
2549 ["update-ref", "--no-deref", "--stdin"], 2683 ["update-ref", "--stdin"],
2550 bare=True, 2684 bare=True,
2551 input="".join(update_ref_cmds), 2685 input="".join(update_ref_cmds),
2552 ).Wait() 2686 ).Wait()
@@ -2794,7 +2928,7 @@ class Project:
2794 ) 2928 )
2795 GitCommand( 2929 GitCommand(
2796 self, 2930 self,
2797 ["update-ref", "--no-deref", "--stdin"], 2931 ["update-ref", "--stdin"],
2798 bare=True, 2932 bare=True,
2799 input=delete_cmds, 2933 input=delete_cmds,
2800 log_as_error=False, 2934 log_as_error=False,
@@ -2809,7 +2943,9 @@ class Project:
2809 # We just synced the upstream given branch; verify we 2943 # We just synced the upstream given branch; verify we
2810 # got what we wanted, else trigger a second run of all 2944 # got what we wanted, else trigger a second run of all
2811 # refs. 2945 # refs.
2812 if not self._CheckForImmutableRevision(): 2946 if not self._CheckForImmutableRevision(
2947 use_superproject=use_superproject
2948 ):
2813 # Sync the current branch only with depth set to None. 2949 # Sync the current branch only with depth set to None.
2814 # We always pass depth=None down to avoid infinite recursion. 2950 # We always pass depth=None down to avoid infinite recursion.
2815 return self._RemoteFetch( 2951 return self._RemoteFetch(
@@ -2817,6 +2953,7 @@ class Project:
2817 quiet=quiet, 2953 quiet=quiet,
2818 verbose=verbose, 2954 verbose=verbose,
2819 output_redir=output_redir, 2955 output_redir=output_redir,
2956 use_superproject=use_superproject,
2820 current_branch_only=current_branch_only and depth, 2957 current_branch_only=current_branch_only and depth,
2821 initial=False, 2958 initial=False,
2822 alt_dir=alt_dir, 2959 alt_dir=alt_dir,
@@ -3939,30 +4076,14 @@ class Project:
3939 def GetHead(self): 4076 def GetHead(self):
3940 """Return the ref that HEAD points to.""" 4077 """Return the ref that HEAD points to."""
3941 try: 4078 try:
3942 symbolic_head = self.rev_parse("--symbolic-full-name", HEAD) 4079 return self.symbolic_ref("-q", HEAD, log_as_error=False)
3943 if symbolic_head == HEAD: 4080 except GitError:
3944 # Detached HEAD. Return the commit SHA instead. 4081 pass
3945 return self.rev_parse(HEAD)
3946 return symbolic_head
3947 except GitError as e:
3948 # `git rev-parse --symbolic-full-name HEAD` will fail for unborn
3949 # branches, so try symbolic-ref before falling back to raw file
3950 # parsing.
3951 try:
3952 p = GitCommand(
3953 self._project,
3954 ["symbolic-ref", "-q", HEAD],
3955 bare=True,
3956 gitdir=self._gitdir,
3957 capture_stdout=True,
3958 capture_stderr=True,
3959 log_as_error=False,
3960 )
3961 if p.Wait() == 0:
3962 return p.stdout.rstrip("\n")
3963 except GitError:
3964 pass
3965 4082
4083 try:
4084 # If symbolic-ref fails, try to treat as detached HEAD.
4085 return self.rev_parse(HEAD)
4086 except GitError as e:
3966 logger.warning( 4087 logger.warning(
3967 "project %s: unparseable HEAD; trying to recover.\n" 4088 "project %s: unparseable HEAD; trying to recover.\n"
3968 "Check that HEAD ref in .git/HEAD is valid. The error " 4089 "Check that HEAD ref in .git/HEAD is valid. The error "
@@ -4827,6 +4948,7 @@ class ManifestProject(MetaProject):
4827 quiet=not verbose, 4948 quiet=not verbose,
4828 verbose=verbose, 4949 verbose=verbose,
4829 clone_bundle=clone_bundle, 4950 clone_bundle=clone_bundle,
4951 use_superproject=use_superproject,
4830 current_branch_only=current_branch_only, 4952 current_branch_only=current_branch_only,
4831 tags=tags, 4953 tags=tags,
4832 submodules=submodules, 4954 submodules=submodules,
diff --git a/release/check-metadata.py b/release/check-metadata.py
index 951bd4c2c..9b9a347c8 100755
--- a/release/check-metadata.py
+++ b/release/check-metadata.py
@@ -106,7 +106,7 @@ def check_path(opts: argparse.Namespace, path: Path) -> bool:
106def check_paths(opts: argparse.Namespace, paths: list[Path]) -> bool: 106def check_paths(opts: argparse.Namespace, paths: list[Path]) -> bool:
107 """Check all the paths.""" 107 """Check all the paths."""
108 # NB: Use list comprehension and not a generator so we check all paths. 108 # NB: Use list comprehension and not a generator so we check all paths.
109 return all([check_path(opts, x) for x in paths]) 109 return all([check_path(opts, x) for x in paths]) # noqa: C419
110 110
111 111
112def find_files(opts: argparse.Namespace) -> list[Path]: 112def find_files(opts: argparse.Namespace) -> list[Path]:
diff --git a/run_tests.vpython3 b/run_tests.vpython3
index e6dfe7c63..e07d08256 100644
--- a/run_tests.vpython3
+++ b/run_tests.vpython3
@@ -48,10 +48,10 @@ wheel: <
48 version: "version:3.0.7" 48 version: "version:3.0.7"
49> 49>
50 50
51# Required by pytest==8.3.4 51# Required by pytest==8.3.4 and flake8-bugbear==24.12.12
52wheel: < 52wheel: <
53 name: "infra/python/wheels/attrs-py2_py3" 53 name: "infra/python/wheels/attrs-py3"
54 version: "version:21.4.0" 54 version: "version:24.2.0"
55> 55>
56 56
57# NB: Keep in sync with constraints.txt. 57# NB: Keep in sync with constraints.txt.
@@ -120,6 +120,16 @@ wheel: <
120> 120>
121 121
122wheel: < 122wheel: <
123 name: "infra/python/wheels/flake8-bugbear-py3"
124 version: "version:24.12.12"
125>
126
127wheel: <
128 name: "infra/python/wheels/flake8-comprehensions-py3"
129 version: "version:3.16.0"
130>
131
132wheel: <
123 name: "infra/python/wheels/isort-py3" 133 name: "infra/python/wheels/isort-py3"
124 version: "version:5.10.1" 134 version: "version:5.10.1"
125> 135>
diff --git a/ssh.py b/ssh.py
index 6de8b89e8..25152f1cc 100644
--- a/ssh.py
+++ b/ssh.py
@@ -149,7 +149,7 @@ class ProxyManager:
149 while True: 149 while True:
150 try: 150 try:
151 procs.pop(0) 151 procs.pop(0)
152 except: # noqa: E722 152 except IndexError:
153 break 153 break
154 154
155 def close(self): 155 def close(self):
diff --git a/subcmds/gc.py b/subcmds/gc.py
index a23d5e068..1d1023a1f 100644
--- a/subcmds/gc.py
+++ b/subcmds/gc.py
@@ -16,6 +16,7 @@ import os
16from typing import List, Set 16from typing import List, Set
17 17
18from command import Command 18from command import Command
19from git_command import git_require
19from git_command import GitCommand 20from git_command import GitCommand
20import platform_utils 21import platform_utils
21from progress import Progress 22from progress import Progress
@@ -204,6 +205,7 @@ class Gc(Command):
204 [ 205 [
205 "rev-list", 206 "rev-list",
206 "--objects", 207 "--objects",
208 "--missing=allow-promisor",
207 f"--remotes={project.remote.name}", 209 f"--remotes={project.remote.name}",
208 "--filter=blob:none", 210 "--filter=blob:none",
209 "--tags", 211 "--tags",
@@ -215,7 +217,12 @@ class Gc(Command):
215 # Get all local objects and pack them. 217 # Get all local objects and pack them.
216 local_head_objects_cmd = GitCommand( 218 local_head_objects_cmd = GitCommand(
217 project, 219 project,
218 ["rev-list", "--objects", "HEAD^{tree}"], 220 [
221 "rev-list",
222 "--objects",
223 "--missing=allow-promisor",
224 "HEAD^{tree}",
225 ],
219 capture_stdout=True, 226 capture_stdout=True,
220 verify_command=True, 227 verify_command=True,
221 ) 228 )
@@ -224,6 +231,7 @@ class Gc(Command):
224 [ 231 [
225 "rev-list", 232 "rev-list",
226 "--objects", 233 "--objects",
234 "--missing=allow-promisor",
227 "--all", 235 "--all",
228 "--reflog", 236 "--reflog",
229 "--indexed-objects", 237 "--indexed-objects",
@@ -297,7 +305,8 @@ class Gc(Command):
297 if ret != 0: 305 if ret != 0:
298 return ret 306 return ret
299 307
300 if not opt.repack: 308 if opt.repack:
301 return 309 git_require((2, 17, 0), fail=True, msg="--repack")
310 ret = self.repack_projects(projects, opt)
302 311
303 return self.repack_projects(projects, opt) 312 return ret
diff --git a/subcmds/grep.py b/subcmds/grep.py
index 85977ce80..e0f239fd3 100644
--- a/subcmds/grep.py
+++ b/subcmds/grep.py
@@ -93,7 +93,7 @@ contain a line that matches both expressions:
93 pt = getattr(parser.values, "cmd_argv", None) 93 pt = getattr(parser.values, "cmd_argv", None)
94 if pt is None: 94 if pt is None:
95 pt = [] 95 pt = []
96 setattr(parser.values, "cmd_argv", pt) 96 parser.values.cmd_argv = pt
97 97
98 if opt_str == "-(": 98 if opt_str == "-(":
99 pt.append("(") 99 pt.append("(")
diff --git a/subcmds/help.py b/subcmds/help.py
index 800407114..df7806fa5 100644
--- a/subcmds/help.py
+++ b/subcmds/help.py
@@ -59,7 +59,7 @@ Displays detailed usage information about a command.
59 59
60 def PrintAllCommandsBody(self): 60 def PrintAllCommandsBody(self):
61 print("The complete list of recognized repo commands is:") 61 print("The complete list of recognized repo commands is:")
62 commandNames = list(sorted(all_commands)) 62 commandNames = sorted(all_commands)
63 self._PrintCommands(commandNames) 63 self._PrintCommands(commandNames)
64 print( 64 print(
65 "See 'repo help <command>' for more information on a " 65 "See 'repo help <command>' for more information on a "
@@ -74,10 +74,8 @@ Displays detailed usage information about a command.
74 def PrintCommonCommandsBody(self): 74 def PrintCommonCommandsBody(self):
75 print("The most commonly used repo commands are:") 75 print("The most commonly used repo commands are:")
76 76
77 commandNames = list( 77 commandNames = sorted(
78 sorted( 78 name for name, command in all_commands.items() if command.COMMON
79 name for name, command in all_commands.items() if command.COMMON
80 )
81 ) 79 )
82 self._PrintCommands(commandNames) 80 self._PrintCommands(commandNames)
83 81
diff --git a/subcmds/sync.py b/subcmds/sync.py
index 89b58e6aa..bfbe1937d 100644
--- a/subcmds/sync.py
+++ b/subcmds/sync.py
@@ -808,6 +808,7 @@ later is required to fix a server side protocol bug.
808 quiet=opt.quiet, 808 quiet=opt.quiet,
809 verbose=opt.verbose, 809 verbose=opt.verbose,
810 output_redir=buf, 810 output_redir=buf,
811 use_superproject=opt.use_superproject,
811 current_branch_only=cls._GetCurrentBranchOnly( 812 current_branch_only=cls._GetCurrentBranchOnly(
812 opt, project.manifest 813 opt, project.manifest
813 ), 814 ),
@@ -946,7 +947,7 @@ later is required to fix a server side protocol bug.
946 "sync_dict" 947 "sync_dict"
947 ] = multiprocessing.Manager().dict() 948 ] = multiprocessing.Manager().dict()
948 949
949 objdir_project_map = dict() 950 objdir_project_map = {}
950 for index, project in enumerate(projects): 951 for index, project in enumerate(projects):
951 objdir_project_map.setdefault(project.objdir, []).append(index) 952 objdir_project_map.setdefault(project.objdir, []).append(index)
952 projects_list = list(objdir_project_map.values()) 953 projects_list = list(objdir_project_map.values())
@@ -1243,14 +1244,15 @@ later is required to fix a server side protocol bug.
1243 1244
1244 return False 1245 return False
1245 1246
1246 def _SetPreciousObjectsState(self, project: Project, opt): 1247 @classmethod
1248 def _SetPreciousObjectsState(cls, project: Project, opt):
1247 """Correct the preciousObjects state for the project. 1249 """Correct the preciousObjects state for the project.
1248 1250
1249 Args: 1251 Args:
1250 project: the project to examine, and possibly correct. 1252 project: the project to examine, and possibly correct.
1251 opt: options given to sync. 1253 opt: options given to sync.
1252 """ 1254 """
1253 expected = self._GetPreciousObjectsState(project, opt) 1255 expected = cls._GetPreciousObjectsState(project, opt)
1254 actual = ( 1256 actual = (
1255 project.config.GetBoolean("extensions.preciousObjects") or False 1257 project.config.GetBoolean("extensions.preciousObjects") or False
1256 ) 1258 )
@@ -1284,7 +1286,22 @@ later is required to fix a server side protocol bug.
1284 project.config.SetString("extensions.preciousObjects", None) 1286 project.config.SetString("extensions.preciousObjects", None)
1285 project.config.SetString("gc.pruneExpire", None) 1287 project.config.SetString("gc.pruneExpire", None)
1286 1288
1287 def _GCProjects(self, projects, opt, err_event): 1289 @staticmethod
1290 def _RunOneGC(project: Project, config: Optional[dict] = None) -> None:
1291 """Run auto GC on a single project."""
1292 local_config = {}
1293 if config:
1294 local_config.update(config)
1295 local_config["gc.autoDetach"] = "false"
1296 project.bare_git.gc("--auto", config=local_config)
1297
1298 @classmethod
1299 def _GCProjects(
1300 cls,
1301 projects: List[Project],
1302 opt: optparse.Values,
1303 err_event: _threading.Event,
1304 ) -> None:
1288 """Perform garbage collection. 1305 """Perform garbage collection.
1289 1306
1290 If We are skipping garbage collection (opt.auto_gc not set), we still 1307 If We are skipping garbage collection (opt.auto_gc not set), we still
@@ -1294,7 +1311,7 @@ later is required to fix a server side protocol bug.
1294 if not opt.auto_gc: 1311 if not opt.auto_gc:
1295 # Just repair preciousObjects state, and return. 1312 # Just repair preciousObjects state, and return.
1296 for project in projects: 1313 for project in projects:
1297 self._SetPreciousObjectsState(project, opt) 1314 cls._SetPreciousObjectsState(project, opt)
1298 return 1315 return
1299 1316
1300 pm = Progress( 1317 pm = Progress(
@@ -1304,9 +1321,8 @@ later is required to fix a server side protocol bug.
1304 1321
1305 tidy_dirs = {} 1322 tidy_dirs = {}
1306 for project in projects: 1323 for project in projects:
1307 self._SetPreciousObjectsState(project, opt) 1324 cls._SetPreciousObjectsState(project, opt)
1308 1325
1309 project.config.SetString("gc.autoDetach", "false")
1310 # Only call git gc once per objdir, but call pack-refs for the 1326 # Only call git gc once per objdir, but call pack-refs for the
1311 # remainder. 1327 # remainder.
1312 if project.objdir not in tidy_dirs: 1328 if project.objdir not in tidy_dirs:
@@ -1327,7 +1343,7 @@ later is required to fix a server side protocol bug.
1327 pm.update(msg=bare_git._project.name) 1343 pm.update(msg=bare_git._project.name)
1328 1344
1329 if run_gc: 1345 if run_gc:
1330 bare_git.gc("--auto") 1346 cls._RunOneGC(bare_git._project)
1331 else: 1347 else:
1332 bare_git.pack_refs() 1348 bare_git.pack_refs()
1333 pm.end() 1349 pm.end()
@@ -1344,7 +1360,7 @@ later is required to fix a server side protocol bug.
1344 try: 1360 try:
1345 try: 1361 try:
1346 if run_gc: 1362 if run_gc:
1347 bare_git.gc("--auto", config=config) 1363 cls._RunOneGC(bare_git._project, config=config)
1348 else: 1364 else:
1349 bare_git.pack_refs(config=config) 1365 bare_git.pack_refs(config=config)
1350 except GitError: 1366 except GitError:
@@ -1447,7 +1463,6 @@ later is required to fix a server side protocol bug.
1447 if not projects: 1463 if not projects:
1448 return 1464 return
1449 1465
1450 bloated_projects = []
1451 pm = Progress( 1466 pm = Progress(
1452 "Checking for bloat", len(projects), delay=False, quiet=opt.quiet 1467 "Checking for bloat", len(projects), delay=False, quiet=opt.quiet
1453 ) 1468 )
@@ -1455,7 +1470,7 @@ later is required to fix a server side protocol bug.
1455 def _ProcessResults(pool, pm, results): 1470 def _ProcessResults(pool, pm, results):
1456 for result in results: 1471 for result in results:
1457 if result: 1472 if result:
1458 bloated_projects.append(result) 1473 self._bloated_projects.append(result)
1459 pm.update(msg="") 1474 pm.update(msg="")
1460 1475
1461 with self.ParallelContext(): 1476 with self.ParallelContext():
@@ -1470,15 +1485,6 @@ later is required to fix a server side protocol bug.
1470 ) 1485 )
1471 pm.end() 1486 pm.end()
1472 1487
1473 for project_name in bloated_projects:
1474 warn_msg = (
1475 f'warning: Project "{project_name}" is accumulating '
1476 'unoptimized data. Please run "repo sync --auto-gc" or '
1477 '"repo gc --repack" to clean up.'
1478 )
1479 self.git_event_log.ErrorEvent(warn_msg)
1480 logger.warning(warn_msg)
1481
1482 def _UpdateRepoProject(self, opt, manifest, errors): 1488 def _UpdateRepoProject(self, opt, manifest, errors):
1483 """Fetch the repo project and check for updates.""" 1489 """Fetch the repo project and check for updates."""
1484 if opt.local_only: 1490 if opt.local_only:
@@ -1499,6 +1505,7 @@ later is required to fix a server side protocol bug.
1499 quiet=opt.quiet, 1505 quiet=opt.quiet,
1500 verbose=opt.verbose, 1506 verbose=opt.verbose,
1501 output_redir=buf, 1507 output_redir=buf,
1508 use_superproject=opt.use_superproject,
1502 current_branch_only=self._GetCurrentBranchOnly( 1509 current_branch_only=self._GetCurrentBranchOnly(
1503 opt, manifest 1510 opt, manifest
1504 ), 1511 ),
@@ -1830,6 +1837,7 @@ later is required to fix a server side protocol bug.
1830 quiet=not opt.verbose, 1837 quiet=not opt.verbose,
1831 output_redir=buf, 1838 output_redir=buf,
1832 verbose=opt.verbose, 1839 verbose=opt.verbose,
1840 use_superproject=opt.use_superproject,
1833 current_branch_only=self._GetCurrentBranchOnly( 1841 current_branch_only=self._GetCurrentBranchOnly(
1834 opt, mp.manifest 1842 opt, mp.manifest
1835 ), 1843 ),
@@ -2094,6 +2102,7 @@ later is required to fix a server side protocol bug.
2094 2102
2095 self._fetch_times = _FetchTimes(manifest) 2103 self._fetch_times = _FetchTimes(manifest)
2096 self._local_sync_state = LocalSyncState(manifest) 2104 self._local_sync_state = LocalSyncState(manifest)
2105 self._bloated_projects = []
2097 2106
2098 if opt.interleaved: 2107 if opt.interleaved:
2099 sync_method = self._SyncInterleaved 2108 sync_method = self._SyncInterleaved
@@ -2110,6 +2119,9 @@ later is required to fix a server side protocol bug.
2110 superproject_logging_data, 2119 superproject_logging_data,
2111 ) 2120 )
2112 2121
2122 if not opt.quiet:
2123 print("Finalizing sync state...")
2124
2113 # Log the previous sync analysis state from the config. 2125 # Log the previous sync analysis state from the config.
2114 self.git_event_log.LogDataConfigEvents( 2126 self.git_event_log.LogDataConfigEvents(
2115 mp.config.GetSyncAnalysisStateData(), "previous_sync_state" 2127 mp.config.GetSyncAnalysisStateData(), "previous_sync_state"
@@ -2131,6 +2143,15 @@ later is required to fix a server side protocol bug.
2131 if existing: 2143 if existing:
2132 self._CheckForBloatedProjects(all_projects, opt) 2144 self._CheckForBloatedProjects(all_projects, opt)
2133 2145
2146 for project_name in sorted(self._bloated_projects):
2147 warn_msg = (
2148 f'warning: Project "{project_name}" is accumulating '
2149 'unoptimized data. Please run "repo sync --auto-gc" or '
2150 '"repo gc --repack" to clean up.'
2151 )
2152 self.git_event_log.ErrorEvent(warn_msg)
2153 logger.warning(warn_msg)
2154
2134 if not opt.quiet: 2155 if not opt.quiet:
2135 print("repo sync has finished successfully.") 2156 print("repo sync has finished successfully.")
2136 2157
@@ -2356,6 +2377,7 @@ later is required to fix a server side protocol bug.
2356 quiet=opt.quiet, 2377 quiet=opt.quiet,
2357 verbose=opt.verbose, 2378 verbose=opt.verbose,
2358 output_redir=network_output_capture, 2379 output_redir=network_output_capture,
2380 use_superproject=opt.use_superproject,
2359 current_branch_only=cls._GetCurrentBranchOnly( 2381 current_branch_only=cls._GetCurrentBranchOnly(
2360 opt, project.manifest 2382 opt, project.manifest
2361 ), 2383 ),
@@ -2653,7 +2675,7 @@ later is required to fix a server side protocol bug.
2653 if previously_pending_relpaths == pending_relpaths: 2675 if previously_pending_relpaths == pending_relpaths:
2654 stalled_projects_str = "\n".join( 2676 stalled_projects_str = "\n".join(
2655 f" - {path}" 2677 f" - {path}"
2656 for path in sorted(list(pending_relpaths)) 2678 for path in sorted(pending_relpaths)
2657 ) 2679 )
2658 logger.error( 2680 logger.error(
2659 "The following projects failed and could " 2681 "The following projects failed and could "
diff --git a/subcmds/upload.py b/subcmds/upload.py
index 4f817ddfb..49d8e2e5f 100644
--- a/subcmds/upload.py
+++ b/subcmds/upload.py
@@ -27,6 +27,7 @@ from error import SilentRepoExitError
27from error import UploadError 27from error import UploadError
28from git_command import GitCommand 28from git_command import GitCommand
29from git_refs import R_HEADS 29from git_refs import R_HEADS
30import git_superproject
30from hooks import RepoHook 31from hooks import RepoHook
31from project import ReviewableBranch 32from project import ReviewableBranch
32from repo_logging import RepoLogger 33from repo_logging import RepoLogger
@@ -627,7 +628,7 @@ Gerrit Code Review: https://www.gerritcodereview.com/
627 # If using superproject, add the root repo as a push option. 628 # If using superproject, add the root repo as a push option.
628 manifest = branch.project.manifest 629 manifest = branch.project.manifest
629 push_options = list(opt.push_options) 630 push_options = list(opt.push_options)
630 if manifest.manifestProject.use_superproject: 631 if git_superproject.UseSuperproject(None, manifest):
631 sp = manifest.superproject 632 sp = manifest.superproject
632 if sp: 633 if sp:
633 r_id = sp.repo_id 634 r_id = sp.repo_id
@@ -802,9 +803,10 @@ Gerrit Code Review: https://www.gerritcodereview.com/
802 project_list=pending_proj_names, worktree_list=pending_worktrees 803 project_list=pending_proj_names, worktree_list=pending_worktrees
803 ): 804 ):
804 if LocalSyncState(manifest).IsPartiallySynced(): 805 if LocalSyncState(manifest).IsPartiallySynced():
805 logger.error( 806 logger.info(
806 "Partially synced tree detected. Syncing all projects " 807 "Tip: A partially synced tree was detected. "
807 "may resolve issues you're seeing." 808 "If this failure involves cross-project dependencies, "
809 "a full `repo sync` might help."
808 ) 810 )
809 ret = 1 811 ret = 1
810 if ret: 812 if ret:
diff --git a/tests/test_color.py b/tests/test_color.py
index 923f7e355..8b75d2199 100644
--- a/tests/test_color.py
+++ b/tests/test_color.py
@@ -14,23 +14,17 @@
14 14
15"""Unittests for the color.py module.""" 15"""Unittests for the color.py module."""
16 16
17import os
18
19import pytest 17import pytest
18import utils_for_test
20 19
21import color 20import color
22import git_config 21import git_config
23 22
24 23
25def fixture(*paths: str) -> str:
26 """Return a path relative to test/fixtures."""
27 return os.path.join(os.path.dirname(__file__), "fixtures", *paths)
28
29
30@pytest.fixture 24@pytest.fixture
31def coloring() -> color.Coloring: 25def coloring() -> color.Coloring:
32 """Create a Coloring object for testing.""" 26 """Create a Coloring object for testing."""
33 config_fixture = fixture("test.gitconfig") 27 config_fixture = utils_for_test.FIXTURES_DIR / "test.gitconfig"
34 config = git_config.GitConfig(config_fixture) 28 config = git_config.GitConfig(config_fixture)
35 color.SetDefaultColoring("true") 29 color.SetDefaultColoring("true")
36 return color.Coloring(config, "status") 30 return color.Coloring(config, "status")
diff --git a/tests/test_git_config.py b/tests/test_git_config.py
index 496d97141..2ece4c6e3 100644
--- a/tests/test_git_config.py
+++ b/tests/test_git_config.py
@@ -14,24 +14,19 @@
14 14
15"""Unittests for the git_config.py module.""" 15"""Unittests for the git_config.py module."""
16 16
17import os
18from pathlib import Path 17from pathlib import Path
19from typing import Any 18from typing import Any
20 19
21import pytest 20import pytest
21import utils_for_test
22 22
23import git_config 23import git_config
24 24
25 25
26def fixture_path(*paths: str) -> str:
27 """Return a path relative to test/fixtures."""
28 return os.path.join(os.path.dirname(__file__), "fixtures", *paths)
29
30
31@pytest.fixture 26@pytest.fixture
32def readonly_config() -> git_config.GitConfig: 27def readonly_config() -> git_config.GitConfig:
33 """Create a GitConfig object using the test.gitconfig fixture.""" 28 """Create a GitConfig object using the test.gitconfig fixture."""
34 config_fixture = fixture_path("test.gitconfig") 29 config_fixture = utils_for_test.FIXTURES_DIR / "test.gitconfig"
35 return git_config.GitConfig(config_fixture) 30 return git_config.GitConfig(config_fixture)
36 31
37 32
@@ -63,7 +58,7 @@ def test_get_string_with_true_value(
63 58
64def test_get_string_from_missing_file() -> None: 59def test_get_string_from_missing_file() -> None:
65 """Test missing config file.""" 60 """Test missing config file."""
66 config_fixture = fixture_path("not.present.gitconfig") 61 config_fixture = utils_for_test.FIXTURES_DIR / "not.present.gitconfig"
67 config = git_config.GitConfig(config_fixture) 62 config = git_config.GitConfig(config_fixture)
68 val = config.GetString("empty") 63 val = config.GetString("empty")
69 assert val is None 64 assert val is None
diff --git a/tests/test_git_trace2_event_log.py b/tests/test_git_trace2_event_log.py
index be2d09b07..9a6ba2052 100644
--- a/tests/test_git_trace2_event_log.py
+++ b/tests/test_git_trace2_event_log.py
@@ -18,17 +18,24 @@ import contextlib
18import io 18import io
19import json 19import json
20import os 20import os
21import re
21import socket 22import socket
22import tempfile 23import tempfile
23import threading 24import threading
24import unittest 25from typing import Any, Dict, List, Optional
25from unittest import mock 26from unittest import mock
26 27
28import pytest
29
27import git_trace2_event_log 30import git_trace2_event_log
28import platform_utils 31import platform_utils
29 32
30 33
31def serverLoggingThread(socket_path, server_ready, received_traces): 34def server_logging_thread(
35 socket_path: str,
36 server_ready: threading.Condition,
37 received_traces: List[str],
38) -> None:
32 """Helper function to receive logs over a Unix domain socket. 39 """Helper function to receive logs over a Unix domain socket.
33 40
34 Appends received messages on the provided socket and appends to 41 Appends received messages on the provided socket and appends to
@@ -57,405 +64,425 @@ def serverLoggingThread(socket_path, server_ready, received_traces):
57 received_traces.extend(data.decode("utf-8").splitlines()) 64 received_traces.extend(data.decode("utf-8").splitlines())
58 65
59 66
60class EventLogTestCase(unittest.TestCase): 67PARENT_SID_KEY = "GIT_TRACE2_PARENT_SID"
61 """TestCase for the EventLog module.""" 68PARENT_SID_VALUE = "parent_sid"
69SELF_SID_REGEX = r"repo-\d+T\d+Z-.*"
70FULL_SID_REGEX = rf"^{PARENT_SID_VALUE}/{SELF_SID_REGEX}"
71
72
73@pytest.fixture
74def event_log() -> git_trace2_event_log.EventLog:
75 """Fixture for the EventLog module."""
76 # By default we initialize with the expected case where
77 # repo launches us (so GIT_TRACE2_PARENT_SID is set).
78 env = {PARENT_SID_KEY: PARENT_SID_VALUE}
79 return git_trace2_event_log.EventLog(env=env)
80
81
82def verify_common_keys(
83 log_entry: Dict[str, Any],
84 expected_event_name: Optional[str] = None,
85 full_sid: bool = True,
86) -> None:
87 """Helper function to verify common event log keys."""
88 assert "event" in log_entry
89 assert "sid" in log_entry
90 assert "thread" in log_entry
91 assert "time" in log_entry
92
93 # Do basic data format validation.
94 if expected_event_name:
95 assert expected_event_name == log_entry["event"]
96 if full_sid:
97 assert re.match(FULL_SID_REGEX, log_entry["sid"])
98 else:
99 assert re.match(SELF_SID_REGEX, log_entry["sid"])
100 assert re.match(r"^\d+-\d+-\d+T\d+:\d+:\d+\.\d+\+00:00$", log_entry["time"])
101
102
103def read_log(log_path: str) -> List[Dict[str, Any]]:
104 """Helper function to read log data into a list."""
105 log_data = []
106 with open(log_path, mode="rb") as f:
107 for line in f:
108 log_data.append(json.loads(line))
109 return log_data
110
111
112def remove_prefix(s: str, prefix: str) -> str:
113 """Return a copy string after removing |prefix| from |s|, if present or
114 the original string."""
115 if s.startswith(prefix):
116 return s[len(prefix) :]
117 else:
118 return s
119
120
121def test_initial_state_with_parent_sid(
122 event_log: git_trace2_event_log.EventLog,
123) -> None:
124 """Test initial state when 'GIT_TRACE2_PARENT_SID' is set by parent."""
125 assert re.match(FULL_SID_REGEX, event_log.full_sid)
126
127
128def test_initial_state_no_parent_sid() -> None:
129 """Test initial state when 'GIT_TRACE2_PARENT_SID' is not set."""
130 # Setup an empty environment dict (no parent sid).
131 event_log = git_trace2_event_log.EventLog(env={})
132 assert re.match(SELF_SID_REGEX, event_log.full_sid)
133
134
135def test_version_event(event_log: git_trace2_event_log.EventLog) -> None:
136 """Test 'version' event data is valid.
137
138 Verify that the 'version' event is written even when no other
139 events are added.
62 140
63 PARENT_SID_KEY = "GIT_TRACE2_PARENT_SID" 141 Expected event log:
64 PARENT_SID_VALUE = "parent_sid" 142 <version event>
65 SELF_SID_REGEX = r"repo-\d+T\d+Z-.*" 143 """
66 FULL_SID_REGEX = rf"^{PARENT_SID_VALUE}/{SELF_SID_REGEX}" 144 with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
145 log_path = event_log.Write(path=tempdir)
146 log_data = read_log(log_path)
147
148 # A log with no added events should only have the version entry.
149 assert len(log_data) == 1
150 version_event = log_data[0]
151 verify_common_keys(version_event, expected_event_name="version")
152 # Check for 'version' event specific fields.
153 assert "evt" in version_event
154 assert "exe" in version_event
155 # Verify "evt" version field is a string.
156 assert isinstance(version_event["evt"], str)
157
158
159def test_start_event(event_log: git_trace2_event_log.EventLog) -> None:
160 """Test and validate 'start' event data is valid.
161
162 Expected event log:
163 <version event>
164 <start event>
165 """
166 event_log.StartEvent([])
167 with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
168 log_path = event_log.Write(path=tempdir)
169 log_data = read_log(log_path)
170
171 assert len(log_data) == 2
172 start_event = log_data[1]
173 verify_common_keys(log_data[0], expected_event_name="version")
174 verify_common_keys(start_event, expected_event_name="start")
175 # Check for 'start' event specific fields.
176 assert "argv" in start_event
177 assert isinstance(start_event["argv"], list)
178
179
180def test_exit_event_result_none(
181 event_log: git_trace2_event_log.EventLog,
182) -> None:
183 """Test 'exit' event data is valid when result is None.
184
185 We expect None result to be converted to 0 in the exit event data.
186
187 Expected event log:
188 <version event>
189 <exit event>
190 """
191 event_log.ExitEvent(None)
192 with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
193 log_path = event_log.Write(path=tempdir)
194 log_data = read_log(log_path)
195
196 assert len(log_data) == 2
197 exit_event = log_data[1]
198 verify_common_keys(log_data[0], expected_event_name="version")
199 verify_common_keys(exit_event, expected_event_name="exit")
200 # Check for 'exit' event specific fields.
201 assert "code" in exit_event
202 # 'None' result should convert to 0 (successful) return code.
203 assert exit_event["code"] == 0
204
205
206def test_exit_event_result_integer(
207 event_log: git_trace2_event_log.EventLog,
208) -> None:
209 """Test 'exit' event data is valid when result is an integer.
210
211 Expected event log:
212 <version event>
213 <exit event>
214 """
215 event_log.ExitEvent(2)
216 with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
217 log_path = event_log.Write(path=tempdir)
218 log_data = read_log(log_path)
219
220 assert len(log_data) == 2
221 exit_event = log_data[1]
222 verify_common_keys(log_data[0], expected_event_name="version")
223 verify_common_keys(exit_event, expected_event_name="exit")
224 # Check for 'exit' event specific fields.
225 assert "code" in exit_event
226 assert exit_event["code"] == 2
227
228
229def test_command_event(event_log: git_trace2_event_log.EventLog) -> None:
230 """Test and validate 'command' event data is valid.
231
232 Expected event log:
233 <version event>
234 <command event>
235 """
236 event_log.CommandEvent(name="repo", subcommands=["init", "this"])
237 with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
238 log_path = event_log.Write(path=tempdir)
239 log_data = read_log(log_path)
240
241 assert len(log_data) == 2
242 command_event = log_data[1]
243 verify_common_keys(log_data[0], expected_event_name="version")
244 verify_common_keys(command_event, expected_event_name="cmd_name")
245 # Check for 'command' event specific fields.
246 assert "name" in command_event
247 assert command_event["name"] == "repo-init-this"
248
249
250def test_def_params_event_repo_config(
251 event_log: git_trace2_event_log.EventLog,
252) -> None:
253 """Test 'def_params' event data outputs only repo config keys.
254
255 Expected event log:
256 <version event>
257 <def_param event>
258 <def_param event>
259 """
260 config = {
261 "git.foo": "bar",
262 "repo.partialclone": "true",
263 "repo.partialclonefilter": "blob:none",
264 }
265 event_log.DefParamRepoEvents(config)
266
267 with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
268 log_path = event_log.Write(path=tempdir)
269 log_data = read_log(log_path)
270
271 assert len(log_data) == 3
272 def_param_events = log_data[1:]
273 verify_common_keys(log_data[0], expected_event_name="version")
274
275 for event in def_param_events:
276 verify_common_keys(event, expected_event_name="def_param")
277 # Check for 'def_param' event specific fields.
278 assert "param" in event
279 assert "value" in event
280 assert event["param"].startswith("repo.")
281
282
283def test_def_params_event_no_repo_config(
284 event_log: git_trace2_event_log.EventLog,
285) -> None:
286 """Test 'def_params' event data won't output non-repo config keys.
287
288 Expected event log:
289 <version event>
290 """
291 config = {
292 "git.foo": "bar",
293 "git.core.foo2": "baz",
294 }
295 event_log.DefParamRepoEvents(config)
296
297 with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
298 log_path = event_log.Write(path=tempdir)
299 log_data = read_log(log_path)
300
301 assert len(log_data) == 1
302 verify_common_keys(log_data[0], expected_event_name="version")
67 303
68 def setUp(self):
69 """Load the event_log module every time."""
70 self._event_log = None
71 # By default we initialize with the expected case where
72 # repo launches us (so GIT_TRACE2_PARENT_SID is set).
73 env = {
74 self.PARENT_SID_KEY: self.PARENT_SID_VALUE,
75 }
76 self._event_log = git_trace2_event_log.EventLog(env=env)
77 self._log_data = None
78 304
79 def verifyCommonKeys( 305def test_data_event_config(event_log: git_trace2_event_log.EventLog) -> None:
80 self, log_entry, expected_event_name=None, full_sid=True 306 """Test 'data' event data outputs all config keys.
307
308 Expected event log:
309 <version event>
310 <data event>
311 <data event>
312 """
313 config = {
314 "git.foo": "bar",
315 "repo.partialclone": "false",
316 "repo.syncstate.superproject.hassuperprojecttag": "true",
317 "repo.syncstate.superproject.sys.argv": ["--", "sync", "protobuf"],
318 }
319 prefix_value = "prefix"
320 event_log.LogDataConfigEvents(config, prefix_value)
321
322 with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
323 log_path = event_log.Write(path=tempdir)
324 log_data = read_log(log_path)
325
326 assert len(log_data) == 5
327 data_events = log_data[1:]
328 verify_common_keys(log_data[0], expected_event_name="version")
329
330 for event in data_events:
331 verify_common_keys(event)
332 # Check for 'data' event specific fields.
333 assert "key" in event
334 assert "value" in event
335 key = event["key"]
336 key = remove_prefix(key, f"{prefix_value}/")
337 value = event["value"]
338 assert event_log.GetDataEventName(value) == event["event"]
339 assert key in config
340 assert value == config[key]
341
342
343def test_error_event(event_log: git_trace2_event_log.EventLog) -> None:
344 """Test and validate 'error' event data is valid.
345
346 Expected event log:
347 <version event>
348 <error event>
349 """
350 msg = "invalid option: --cahced"
351 fmt = "invalid option: %s"
352 event_log.ErrorEvent(msg, fmt)
353 with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
354 log_path = event_log.Write(path=tempdir)
355 log_data = read_log(log_path)
356
357 assert len(log_data) == 2
358 error_event = log_data[1]
359 verify_common_keys(log_data[0], expected_event_name="version")
360 verify_common_keys(error_event, expected_event_name="error")
361 # Check for 'error' event specific fields.
362 assert "msg" in error_event
363 assert "fmt" in error_event
364 assert error_event["msg"] == f"RepoErrorEvent:{msg}"
365 assert error_event["fmt"] == f"RepoErrorEvent:{fmt}"
366
367
368def test_write_with_filename(event_log: git_trace2_event_log.EventLog) -> None:
369 """Test Write() with a path to a file exits with None."""
370 assert event_log.Write(path="path/to/file") is None
371
372
373def test_write_with_git_config(
374 tmp_path,
375 event_log: git_trace2_event_log.EventLog,
376) -> None:
377 """Test Write() uses the git config path when 'git config' call succeeds."""
378 with mock.patch.object(
379 event_log,
380 "_GetEventTargetPath",
381 return_value=str(tmp_path),
81 ): 382 ):
82 """Helper function to verify common event log keys.""" 383 assert os.path.dirname(event_log.Write()) == str(tmp_path)
83 self.assertIn("event", log_entry) 384
84 self.assertIn("sid", log_entry) 385
85 self.assertIn("thread", log_entry) 386def test_write_no_git_config(event_log: git_trace2_event_log.EventLog) -> None:
86 self.assertIn("time", log_entry) 387 """Test Write() with no git config variable present exits with None."""
87 388 with mock.patch.object(event_log, "_GetEventTargetPath", return_value=None):
88 # Do basic data format validation. 389 assert event_log.Write() is None
89 if expected_event_name: 390
90 self.assertEqual(expected_event_name, log_entry["event"]) 391
91 if full_sid: 392def test_write_non_string(event_log: git_trace2_event_log.EventLog) -> None:
92 self.assertRegex(log_entry["sid"], self.FULL_SID_REGEX) 393 """Test Write() with non-string type for |path| throws TypeError."""
93 else: 394 with pytest.raises(TypeError):
94 self.assertRegex(log_entry["sid"], self.SELF_SID_REGEX) 395 event_log.Write(path=1234)
95 self.assertRegex( 396
96 log_entry["time"], r"^\d+-\d+-\d+T\d+:\d+:\d+\.\d+\+00:00$" 397
398@pytest.mark.skipif(
399 not hasattr(socket, "AF_UNIX"), reason="Requires AF_UNIX sockets"
400)
401def test_write_socket(event_log: git_trace2_event_log.EventLog) -> None:
402 """Test Write() with Unix domain socket and validate received traces."""
403 received_traces: List[str] = []
404 with tempfile.TemporaryDirectory(prefix="test_server_sockets") as tempdir:
405 socket_path = os.path.join(tempdir, "server.sock")
406 server_ready = threading.Condition()
407 # Start "server" listening on Unix domain socket at socket_path.
408 server_thread = threading.Thread(
409 target=server_logging_thread,
410 args=(socket_path, server_ready, received_traces),
97 ) 411 )
412 try:
413 server_thread.start()
98 414
99 def readLog(self, log_path): 415 with server_ready:
100 """Helper function to read log data into a list.""" 416 server_ready.wait(timeout=120)
101 log_data = []
102 with open(log_path, mode="rb") as f:
103 for line in f:
104 log_data.append(json.loads(line))
105 return log_data
106
107 def remove_prefix(self, s, prefix):
108 """Return a copy string after removing |prefix| from |s|, if present or
109 the original string."""
110 if s.startswith(prefix):
111 return s[len(prefix) :]
112 else:
113 return s
114
115 def test_initial_state_with_parent_sid(self):
116 """Test initial state when 'GIT_TRACE2_PARENT_SID' is set by parent."""
117 self.assertRegex(self._event_log.full_sid, self.FULL_SID_REGEX)
118
119 def test_initial_state_no_parent_sid(self):
120 """Test initial state when 'GIT_TRACE2_PARENT_SID' is not set."""
121 # Setup an empty environment dict (no parent sid).
122 self._event_log = git_trace2_event_log.EventLog(env={})
123 self.assertRegex(self._event_log.full_sid, self.SELF_SID_REGEX)
124
125 def test_version_event(self):
126 """Test 'version' event data is valid.
127
128 Verify that the 'version' event is written even when no other
129 events are addded.
130
131 Expected event log:
132 <version event>
133 """
134 with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
135 log_path = self._event_log.Write(path=tempdir)
136 self._log_data = self.readLog(log_path)
137
138 # A log with no added events should only have the version entry.
139 self.assertEqual(len(self._log_data), 1)
140 version_event = self._log_data[0]
141 self.verifyCommonKeys(version_event, expected_event_name="version")
142 # Check for 'version' event specific fields.
143 self.assertIn("evt", version_event)
144 self.assertIn("exe", version_event)
145 # Verify "evt" version field is a string.
146 self.assertIsInstance(version_event["evt"], str)
147
148 def test_start_event(self):
149 """Test and validate 'start' event data is valid.
150
151 Expected event log:
152 <version event>
153 <start event>
154 """
155 self._event_log.StartEvent([])
156 with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
157 log_path = self._event_log.Write(path=tempdir)
158 self._log_data = self.readLog(log_path)
159
160 self.assertEqual(len(self._log_data), 2)
161 start_event = self._log_data[1]
162 self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
163 self.verifyCommonKeys(start_event, expected_event_name="start")
164 # Check for 'start' event specific fields.
165 self.assertIn("argv", start_event)
166 self.assertTrue(isinstance(start_event["argv"], list))
167
168 def test_exit_event_result_none(self):
169 """Test 'exit' event data is valid when result is None.
170
171 We expect None result to be converted to 0 in the exit event data.
172
173 Expected event log:
174 <version event>
175 <exit event>
176 """
177 self._event_log.ExitEvent(None)
178 with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
179 log_path = self._event_log.Write(path=tempdir)
180 self._log_data = self.readLog(log_path)
181
182 self.assertEqual(len(self._log_data), 2)
183 exit_event = self._log_data[1]
184 self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
185 self.verifyCommonKeys(exit_event, expected_event_name="exit")
186 # Check for 'exit' event specific fields.
187 self.assertIn("code", exit_event)
188 # 'None' result should convert to 0 (successful) return code.
189 self.assertEqual(exit_event["code"], 0)
190
191 def test_exit_event_result_integer(self):
192 """Test 'exit' event data is valid when result is an integer.
193
194 Expected event log:
195 <version event>
196 <exit event>
197 """
198 self._event_log.ExitEvent(2)
199 with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
200 log_path = self._event_log.Write(path=tempdir)
201 self._log_data = self.readLog(log_path)
202
203 self.assertEqual(len(self._log_data), 2)
204 exit_event = self._log_data[1]
205 self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
206 self.verifyCommonKeys(exit_event, expected_event_name="exit")
207 # Check for 'exit' event specific fields.
208 self.assertIn("code", exit_event)
209 self.assertEqual(exit_event["code"], 2)
210
211 def test_command_event(self):
212 """Test and validate 'command' event data is valid.
213
214 Expected event log:
215 <version event>
216 <command event>
217 """
218 self._event_log.CommandEvent(name="repo", subcommands=["init", "this"])
219 with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
220 log_path = self._event_log.Write(path=tempdir)
221 self._log_data = self.readLog(log_path)
222
223 self.assertEqual(len(self._log_data), 2)
224 command_event = self._log_data[1]
225 self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
226 self.verifyCommonKeys(command_event, expected_event_name="cmd_name")
227 # Check for 'command' event specific fields.
228 self.assertIn("name", command_event)
229 self.assertEqual(command_event["name"], "repo-init-this")
230
231 def test_def_params_event_repo_config(self):
232 """Test 'def_params' event data outputs only repo config keys.
233
234 Expected event log:
235 <version event>
236 <def_param event>
237 <def_param event>
238 """
239 config = {
240 "git.foo": "bar",
241 "repo.partialclone": "true",
242 "repo.partialclonefilter": "blob:none",
243 }
244 self._event_log.DefParamRepoEvents(config)
245
246 with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
247 log_path = self._event_log.Write(path=tempdir)
248 self._log_data = self.readLog(log_path)
249
250 self.assertEqual(len(self._log_data), 3)
251 def_param_events = self._log_data[1:]
252 self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
253
254 for event in def_param_events:
255 self.verifyCommonKeys(event, expected_event_name="def_param")
256 # Check for 'def_param' event specific fields.
257 self.assertIn("param", event)
258 self.assertIn("value", event)
259 self.assertTrue(event["param"].startswith("repo."))
260
261 def test_def_params_event_no_repo_config(self):
262 """Test 'def_params' event data won't output non-repo config keys.
263
264 Expected event log:
265 <version event>
266 """
267 config = {
268 "git.foo": "bar",
269 "git.core.foo2": "baz",
270 }
271 self._event_log.DefParamRepoEvents(config)
272
273 with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
274 log_path = self._event_log.Write(path=tempdir)
275 self._log_data = self.readLog(log_path)
276
277 self.assertEqual(len(self._log_data), 1)
278 self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
279
280 def test_data_event_config(self):
281 """Test 'data' event data outputs all config keys.
282
283 Expected event log:
284 <version event>
285 <data event>
286 <data event>
287 """
288 config = {
289 "git.foo": "bar",
290 "repo.partialclone": "false",
291 "repo.syncstate.superproject.hassuperprojecttag": "true",
292 "repo.syncstate.superproject.sys.argv": ["--", "sync", "protobuf"],
293 }
294 prefix_value = "prefix"
295 self._event_log.LogDataConfigEvents(config, prefix_value)
296
297 with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
298 log_path = self._event_log.Write(path=tempdir)
299 self._log_data = self.readLog(log_path)
300
301 self.assertEqual(len(self._log_data), 5)
302 data_events = self._log_data[1:]
303 self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
304
305 for event in data_events:
306 self.verifyCommonKeys(event)
307 # Check for 'data' event specific fields.
308 self.assertIn("key", event)
309 self.assertIn("value", event)
310 key = event["key"]
311 key = self.remove_prefix(key, f"{prefix_value}/")
312 value = event["value"]
313 self.assertEqual(
314 self._event_log.GetDataEventName(value), event["event"]
315 )
316 self.assertTrue(key in config and value == config[key])
317
318 def test_error_event(self):
319 """Test and validate 'error' event data is valid.
320
321 Expected event log:
322 <version event>
323 <error event>
324 """
325 msg = "invalid option: --cahced"
326 fmt = "invalid option: %s"
327 self._event_log.ErrorEvent(msg, fmt)
328 with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
329 log_path = self._event_log.Write(path=tempdir)
330 self._log_data = self.readLog(log_path)
331
332 self.assertEqual(len(self._log_data), 2)
333 error_event = self._log_data[1]
334 self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
335 self.verifyCommonKeys(error_event, expected_event_name="error")
336 # Check for 'error' event specific fields.
337 self.assertIn("msg", error_event)
338 self.assertIn("fmt", error_event)
339 self.assertEqual(error_event["msg"], f"RepoErrorEvent:{msg}")
340 self.assertEqual(error_event["fmt"], f"RepoErrorEvent:{fmt}")
341
342 def test_write_with_filename(self):
343 """Test Write() with a path to a file exits with None."""
344 self.assertIsNone(self._event_log.Write(path="path/to/file"))
345
346 def test_write_with_git_config(self):
347 """Test Write() uses the git config path when 'git config' call
348 succeeds."""
349 with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
350 with mock.patch.object(
351 self._event_log,
352 "_GetEventTargetPath",
353 return_value=tempdir,
354 ):
355 self.assertEqual(
356 os.path.dirname(self._event_log.Write()), tempdir
357 )
358
359 def test_write_no_git_config(self):
360 """Test Write() with no git config variable present exits with None."""
361 with mock.patch.object(
362 self._event_log, "_GetEventTargetPath", return_value=None
363 ):
364 self.assertIsNone(self._event_log.Write())
365
366 def test_write_non_string(self):
367 """Test Write() with non-string type for |path| throws TypeError."""
368 with self.assertRaises(TypeError):
369 self._event_log.Write(path=1234)
370
371 @unittest.skipIf(not hasattr(socket, "AF_UNIX"), "Requires AF_UNIX sockets")
372 def test_write_socket(self):
373 """Test Write() with Unix domain socket for |path| and validate received
374 traces."""
375 received_traces = []
376 with tempfile.TemporaryDirectory(
377 prefix="test_server_sockets"
378 ) as tempdir:
379 socket_path = os.path.join(tempdir, "server.sock")
380 server_ready = threading.Condition()
381 # Start "server" listening on Unix domain socket at socket_path.
382 server_thread = threading.Thread(
383 target=serverLoggingThread,
384 args=(socket_path, server_ready, received_traces),
385 )
386 try:
387 server_thread.start()
388
389 with server_ready:
390 server_ready.wait(timeout=120)
391 417
392 self._event_log.StartEvent([]) 418 event_log.StartEvent([])
393 path = self._event_log.Write(path=f"af_unix:{socket_path}") 419 path = event_log.Write(path=f"af_unix:{socket_path}")
394 finally: 420 finally:
395 server_thread.join(timeout=5) 421 server_thread.join(timeout=5)
396 422
397 self.assertEqual(path, f"af_unix:stream:{socket_path}") 423 assert path == f"af_unix:stream:{socket_path}"
398 self.assertEqual(len(received_traces), 2) 424 assert len(received_traces) == 2
399 version_event = json.loads(received_traces[0]) 425 version_event = json.loads(received_traces[0])
400 start_event = json.loads(received_traces[1]) 426 start_event = json.loads(received_traces[1])
401 self.verifyCommonKeys(version_event, expected_event_name="version") 427 verify_common_keys(version_event, expected_event_name="version")
402 self.verifyCommonKeys(start_event, expected_event_name="start") 428 verify_common_keys(start_event, expected_event_name="start")
403 # Check for 'start' event specific fields. 429 # Check for 'start' event specific fields.
404 self.assertIn("argv", start_event) 430 assert "argv" in start_event
405 self.assertIsInstance(start_event["argv"], list) 431 assert isinstance(start_event["argv"], list)
406 432
407 433
408class EventLogVerboseTestCase(unittest.TestCase): 434class TestEventLogVerbose:
409 """TestCase for the EventLog module verbose logging.""" 435 """TestCase for the EventLog module verbose logging."""
410 436
411 def setUp(self): 437 def test_write_socket_error_no_verbose(self) -> None:
412 self._event_log = git_trace2_event_log.EventLog(env={})
413
414 def test_write_socket_error_no_verbose(self):
415 """Test Write() suppression of socket errors when not verbose.""" 438 """Test Write() suppression of socket errors when not verbose."""
416 self._event_log.verbose = False 439 event_log = git_trace2_event_log.EventLog(env={})
440 event_log.verbose = False
417 with contextlib.redirect_stderr( 441 with contextlib.redirect_stderr(
418 io.StringIO() 442 io.StringIO()
419 ) as mock_stderr, mock.patch("socket.socket", side_effect=OSError): 443 ) as mock_stderr, mock.patch("socket.socket", side_effect=OSError):
420 self._event_log.Write(path="af_unix:stream:/tmp/test_sock") 444 event_log.Write(path="af_unix:stream:/tmp/test_sock")
421 self.assertEqual(mock_stderr.getvalue(), "") 445 assert mock_stderr.getvalue() == ""
422 446
423 def test_write_socket_error_verbose(self): 447 def test_write_socket_error_verbose(self) -> None:
424 """Test Write() printing of socket errors when verbose.""" 448 """Test Write() printing of socket errors when verbose."""
425 self._event_log.verbose = True 449 event_log = git_trace2_event_log.EventLog(env={})
450 event_log.verbose = True
426 with contextlib.redirect_stderr( 451 with contextlib.redirect_stderr(
427 io.StringIO() 452 io.StringIO()
428 ) as mock_stderr, mock.patch( 453 ) as mock_stderr, mock.patch(
429 "socket.socket", side_effect=OSError("Mock error") 454 "socket.socket", side_effect=OSError("Mock error")
430 ): 455 ):
431 self._event_log.Write(path="af_unix:stream:/tmp/test_sock") 456 event_log.Write(path="af_unix:stream:/tmp/test_sock")
432 self.assertIn( 457 assert (
433 "git trace2 logging failed: Mock error", 458 "git trace2 logging failed: Mock error"
434 mock_stderr.getvalue(), 459 in mock_stderr.getvalue()
435 ) 460 )
436 461
437 def test_write_file_error_no_verbose(self): 462 def test_write_file_error_no_verbose(self) -> None:
438 """Test Write() suppression of file errors when not verbose.""" 463 """Test Write() suppression of file errors when not verbose."""
439 self._event_log.verbose = False 464 event_log = git_trace2_event_log.EventLog(env={})
465 event_log.verbose = False
440 with contextlib.redirect_stderr( 466 with contextlib.redirect_stderr(
441 io.StringIO() 467 io.StringIO()
442 ) as mock_stderr, mock.patch( 468 ) as mock_stderr, mock.patch(
443 "tempfile.NamedTemporaryFile", side_effect=FileExistsError 469 "tempfile.NamedTemporaryFile", side_effect=FileExistsError
444 ): 470 ):
445 self._event_log.Write(path="/tmp") 471 event_log.Write(path="/tmp")
446 self.assertEqual(mock_stderr.getvalue(), "") 472 assert mock_stderr.getvalue() == ""
447 473
448 def test_write_file_error_verbose(self): 474 def test_write_file_error_verbose(self) -> None:
449 """Test Write() printing of file errors when verbose.""" 475 """Test Write() printing of file errors when verbose."""
450 self._event_log.verbose = True 476 event_log = git_trace2_event_log.EventLog(env={})
477 event_log.verbose = True
451 with contextlib.redirect_stderr( 478 with contextlib.redirect_stderr(
452 io.StringIO() 479 io.StringIO()
453 ) as mock_stderr, mock.patch( 480 ) as mock_stderr, mock.patch(
454 "tempfile.NamedTemporaryFile", 481 "tempfile.NamedTemporaryFile",
455 side_effect=FileExistsError("Mock error"), 482 side_effect=FileExistsError("Mock error"),
456 ): 483 ):
457 self._event_log.Write(path="/tmp") 484 event_log.Write(path="/tmp")
458 self.assertIn( 485 assert (
459 "git trace2 logging failed: FileExistsError", 486 "git trace2 logging failed: FileExistsError"
460 mock_stderr.getvalue(), 487 in mock_stderr.getvalue()
461 ) 488 )
diff --git a/tests/test_main.py b/tests/test_main.py
new file mode 100644
index 000000000..21bb29c43
--- /dev/null
+++ b/tests/test_main.py
@@ -0,0 +1,166 @@
1# Copyright (C) 2026 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Tests for the main repo script and subcommand routing."""
16
17from unittest import mock
18
19import pytest
20
21from main import _Repo
22
23
24@pytest.fixture(name="repo")
25def fixture_repo():
26 repo = _Repo("repodir")
27 # Overriding the command list here ensures that we are only testing
28 # against a fixed set of commands, reducing fragility to new
29 # subcommands being added to the main repo tool.
30 repo.commands = {"start": None, "sync": None, "smart": None}
31 return repo
32
33
34@pytest.fixture(name="mock_config")
35def fixture_mock_config():
36 return mock.MagicMock()
37
38
39@mock.patch("time.sleep")
40def test_autocorrect_delay(mock_sleep, repo, mock_config):
41 """Test autocorrect with positive delay."""
42 mock_config.GetString.return_value = "10"
43
44 res = repo._autocorrect_command_name("tart", mock_config)
45
46 mock_config.GetString.assert_called_with("help.autocorrect")
47 mock_sleep.assert_called_with(1.0)
48 assert res == "start"
49
50
51@mock.patch("time.sleep")
52def test_autocorrect_delay_one(mock_sleep, repo, mock_config):
53 """Test autocorrect with '1' (0.1s delay, not immediate)."""
54 mock_config.GetString.return_value = "1"
55
56 res = repo._autocorrect_command_name("tart", mock_config)
57
58 mock_sleep.assert_called_with(0.1)
59 assert res == "start"
60
61
62@mock.patch("time.sleep", side_effect=KeyboardInterrupt())
63def test_autocorrect_delay_interrupt(mock_sleep, repo, mock_config):
64 """Test autocorrect handles KeyboardInterrupt during delay."""
65 mock_config.GetString.return_value = "10"
66
67 res = repo._autocorrect_command_name("tart", mock_config)
68
69 mock_sleep.assert_called_with(1.0)
70 assert res is None
71
72
73@mock.patch("time.sleep")
74def test_autocorrect_immediate(mock_sleep, repo, mock_config):
75 """Test autocorrect with immediate/negative delay."""
76 # Test numeric negative.
77 mock_config.GetString.return_value = "-1"
78 res = repo._autocorrect_command_name("tart", mock_config)
79 mock_sleep.assert_not_called()
80 assert res == "start"
81
82 # Test string boolean "true".
83 mock_config.GetString.return_value = "true"
84 res = repo._autocorrect_command_name("tart", mock_config)
85 mock_sleep.assert_not_called()
86 assert res == "start"
87
88 # Test string boolean "yes".
89 mock_config.GetString.return_value = "YES"
90 res = repo._autocorrect_command_name("tart", mock_config)
91 mock_sleep.assert_not_called()
92 assert res == "start"
93
94 # Test string boolean "immediate".
95 mock_config.GetString.return_value = "Immediate"
96 res = repo._autocorrect_command_name("tart", mock_config)
97 mock_sleep.assert_not_called()
98 assert res == "start"
99
100
101def test_autocorrect_zero_or_show(repo, mock_config):
102 """Test autocorrect with zero delay (suggestions only)."""
103 # Test numeric zero.
104 mock_config.GetString.return_value = "0"
105 res = repo._autocorrect_command_name("tart", mock_config)
106 assert res is None
107
108 # Test string boolean "false".
109 mock_config.GetString.return_value = "False"
110 res = repo._autocorrect_command_name("tart", mock_config)
111 assert res is None
112
113 # Test string boolean "show".
114 mock_config.GetString.return_value = "show"
115 res = repo._autocorrect_command_name("tart", mock_config)
116 assert res is None
117
118
119def test_autocorrect_never(repo, mock_config):
120 """Test autocorrect with 'never'."""
121 mock_config.GetString.return_value = "never"
122 res = repo._autocorrect_command_name("tart", mock_config)
123 assert res is None
124
125
126@mock.patch("builtins.input", return_value="y")
127def test_autocorrect_prompt_yes(mock_input, repo, mock_config):
128 """Test autocorrect with prompt and user answers yes."""
129 mock_config.GetString.return_value = "prompt"
130
131 res = repo._autocorrect_command_name("tart", mock_config)
132
133 assert res == "start"
134
135
136@mock.patch("builtins.input", return_value="n")
137def test_autocorrect_prompt_no(mock_input, repo, mock_config):
138 """Test autocorrect with prompt and user answers no."""
139 mock_config.GetString.return_value = "prompt"
140
141 res = repo._autocorrect_command_name("tart", mock_config)
142
143 assert res is None
144
145
146@mock.patch("builtins.input", return_value="y")
147def test_autocorrect_multiple_candidates(mock_input, repo, mock_config):
148 """Test autocorrect with multiple matches forces a prompt."""
149 mock_config.GetString.return_value = "10" # Normally just delay
150
151 # 'snart' matches both 'start' and 'smart' with > 0.7 ratio
152 res = repo._autocorrect_command_name("snart", mock_config)
153
154 # Because there are multiple candidates, it should prompt
155 mock_input.assert_called_once()
156 assert res == "start"
157
158
159@mock.patch("builtins.input", side_effect=KeyboardInterrupt())
160def test_autocorrect_prompt_interrupt(mock_input, repo, mock_config):
161 """Test autocorrect with prompt and user interrupts."""
162 mock_config.GetString.return_value = "prompt"
163
164 res = repo._autocorrect_command_name("tart", mock_config)
165
166 assert res is None
diff --git a/tests/test_manifest_xml.py b/tests/test_manifest_xml.py
index 5e0c78334..c7352f89a 100644
--- a/tests/test_manifest_xml.py
+++ b/tests/test_manifest_xml.py
@@ -18,10 +18,10 @@ import os
18from pathlib import Path 18from pathlib import Path
19import platform 19import platform
20import re 20import re
21import tempfile
22import unittest
23import xml.dom.minidom 21import xml.dom.minidom
24 22
23import pytest
24
25import error 25import error
26import manifest_xml 26import manifest_xml
27 27
@@ -66,7 +66,7 @@ if os.path.sep != "/":
66 ) 66 )
67 67
68 68
69def sort_attributes(manifest): 69def sort_attributes(manifest: str) -> str:
70 """Sort the attributes of all elements alphabetically. 70 """Sort the attributes of all elements alphabetically.
71 71
72 This is needed because different versions of the toxml() function from 72 This is needed because different versions of the toxml() function from
@@ -93,13 +93,12 @@ def sort_attributes(manifest):
93 return new_manifest 93 return new_manifest
94 94
95 95
96class ManifestParseTestCase(unittest.TestCase): 96class RepoClient:
97 """TestCase for parsing manifests.""" 97 """Basic empty repo checkout."""
98 98
99 def setUp(self): 99 def __init__(self, topdir: Path):
100 self.tempdirobj = tempfile.TemporaryDirectory(prefix="repo_tests") 100 self.topdir = topdir
101 self.tempdir = Path(self.tempdirobj.name) 101 self.repodir = self.topdir / ".repo"
102 self.repodir = self.tempdir / ".repo"
103 self.manifest_dir = self.repodir / "manifests" 102 self.manifest_dir = self.repodir / "manifests"
104 self.manifest_file = self.repodir / manifest_xml.MANIFEST_FILE_NAME 103 self.manifest_file = self.repodir / manifest_xml.MANIFEST_FILE_NAME
105 self.local_manifest_dir = ( 104 self.local_manifest_dir = (
@@ -107,7 +106,6 @@ class ManifestParseTestCase(unittest.TestCase):
107 ) 106 )
108 self.repodir.mkdir() 107 self.repodir.mkdir()
109 self.manifest_dir.mkdir() 108 self.manifest_dir.mkdir()
110
111 # The manifest parsing really wants a git repo currently. 109 # The manifest parsing really wants a git repo currently.
112 gitdir = self.repodir / "manifests.git" 110 gitdir = self.repodir / "manifests.git"
113 gitdir.mkdir() 111 gitdir.mkdir()
@@ -117,10 +115,7 @@ class ManifestParseTestCase(unittest.TestCase):
117""" 115"""
118 ) 116 )
119 117
120 def tearDown(self): 118 def get_xml_manifest(self, data: str) -> manifest_xml.XmlManifest:
121 self.tempdirobj.cleanup()
122
123 def getXmlManifest(self, data):
124 """Helper to initialize a manifest for testing.""" 119 """Helper to initialize a manifest for testing."""
125 self.manifest_file.write_text(data, encoding="utf-8") 120 self.manifest_file.write_text(data, encoding="utf-8")
126 return manifest_xml.XmlManifest( 121 return manifest_xml.XmlManifest(
@@ -128,33 +123,43 @@ class ManifestParseTestCase(unittest.TestCase):
128 ) 123 )
129 124
130 @staticmethod 125 @staticmethod
131 def encodeXmlAttr(attr): 126 def encode_xml_attr(attr: str) -> str:
132 """Encode |attr| using XML escape rules.""" 127 """Encode |attr| using XML escape rules."""
133 return attr.replace("\r", "&#x000d;").replace("\n", "&#x000a;") 128 return attr.replace("\r", "&#x000d;").replace("\n", "&#x000a;")
134 129
135 130
136class ManifestValidateFilePaths(unittest.TestCase): 131@pytest.fixture
132def repo_client(tmp_path: Path) -> RepoClient:
133 """Generate a basic empty repo checkout.
134
135 The manifest is not generated.
136 """
137 return RepoClient(tmp_path)
138
139
140class TestManifestValidateFilePaths:
137 """Check _ValidateFilePaths helper. 141 """Check _ValidateFilePaths helper.
138 142
139 This doesn't access a real filesystem. 143 This doesn't access a real filesystem.
140 """ 144 """
141 145
142 def check_both(self, *args): 146 def check_both(self, src: str, dest: str) -> None:
143 manifest_xml.XmlManifest._ValidateFilePaths("copyfile", *args) 147 """Check copyfile & linkfile."""
144 manifest_xml.XmlManifest._ValidateFilePaths("linkfile", *args) 148 manifest_xml.XmlManifest._ValidateFilePaths("copyfile", src, dest)
149 manifest_xml.XmlManifest._ValidateFilePaths("linkfile", src, dest)
145 150
146 def test_normal_path(self): 151 def test_normal_path(self) -> None:
147 """Make sure good paths are accepted.""" 152 """Make sure good paths are accepted."""
148 self.check_both("foo", "bar") 153 self.check_both("foo", "bar")
149 self.check_both("foo/bar", "bar") 154 self.check_both("foo/bar", "bar")
150 self.check_both("foo", "bar/bar") 155 self.check_both("foo", "bar/bar")
151 self.check_both("foo/bar", "bar/bar") 156 self.check_both("foo/bar", "bar/bar")
152 157
153 def test_symlink_targets(self): 158 def test_symlink_targets(self) -> None:
154 """Some extra checks for symlinks.""" 159 """Some extra checks for symlinks."""
155 160
156 def check(*args): 161 def check(src: str, dest: str) -> None:
157 manifest_xml.XmlManifest._ValidateFilePaths("linkfile", *args) 162 manifest_xml.XmlManifest._ValidateFilePaths("linkfile", src, dest)
158 163
159 # We allow symlinks to end in a slash since we allow them to point to 164 # We allow symlinks to end in a slash since we allow them to point to
160 # dirs in general. Technically the slash isn't necessary. 165 # dirs in general. Technically the slash isn't necessary.
@@ -162,114 +167,111 @@ class ManifestValidateFilePaths(unittest.TestCase):
162 # We allow a single '.' to get a reference to the project itself. 167 # We allow a single '.' to get a reference to the project itself.
163 check(".", "bar") 168 check(".", "bar")
164 169
165 def test_bad_paths(self): 170 def test_bad_paths(self) -> None:
166 """Make sure bad paths (src & dest) are rejected.""" 171 """Make sure bad paths (src & dest) are rejected."""
167 for path in INVALID_FS_PATHS: 172 for path in INVALID_FS_PATHS:
168 self.assertRaises( 173 with pytest.raises(error.ManifestInvalidPathError):
169 error.ManifestInvalidPathError, self.check_both, path, "a" 174 self.check_both(path, "a")
170 ) 175 with pytest.raises(error.ManifestInvalidPathError):
171 self.assertRaises( 176 self.check_both("a", path)
172 error.ManifestInvalidPathError, self.check_both, "a", path
173 )
174 177
175 178
176class ValueTests(unittest.TestCase): 179class TestValue:
177 """Check utility parsing code.""" 180 """Check utility parsing code."""
178 181
179 def _get_node(self, text): 182 def _get_node(self, text: str) -> xml.dom.minidom.Element:
180 return xml.dom.minidom.parseString(text).firstChild 183 return xml.dom.minidom.parseString(text).firstChild
181 184
182 def test_bool_default(self): 185 def test_bool_default(self) -> None:
183 """Check XmlBool default handling.""" 186 """Check XmlBool default handling."""
184 node = self._get_node("<node/>") 187 node = self._get_node("<node/>")
185 self.assertIsNone(manifest_xml.XmlBool(node, "a")) 188 assert manifest_xml.XmlBool(node, "a") is None
186 self.assertIsNone(manifest_xml.XmlBool(node, "a", None)) 189 assert manifest_xml.XmlBool(node, "a", None) is None
187 self.assertEqual(123, manifest_xml.XmlBool(node, "a", 123)) 190 assert manifest_xml.XmlBool(node, "a", 123) == 123
188 191
189 node = self._get_node('<node a=""/>') 192 node = self._get_node('<node a=""/>')
190 self.assertIsNone(manifest_xml.XmlBool(node, "a")) 193 assert manifest_xml.XmlBool(node, "a") is None
191 194
192 def test_bool_invalid(self): 195 def test_bool_invalid(self) -> None:
193 """Check XmlBool invalid handling.""" 196 """Check XmlBool invalid handling."""
194 node = self._get_node('<node a="moo"/>') 197 node = self._get_node('<node a="moo"/>')
195 self.assertEqual(123, manifest_xml.XmlBool(node, "a", 123)) 198 assert manifest_xml.XmlBool(node, "a", 123) == 123
196 199
197 def test_bool_true(self): 200 def test_bool_true(self) -> None:
198 """Check XmlBool true values.""" 201 """Check XmlBool true values."""
199 for value in ("yes", "true", "1"): 202 for value in ("yes", "true", "1"):
200 node = self._get_node(f'<node a="{value}"/>') 203 node = self._get_node(f'<node a="{value}"/>')
201 self.assertTrue(manifest_xml.XmlBool(node, "a")) 204 assert manifest_xml.XmlBool(node, "a") is True
202 205
203 def test_bool_false(self): 206 def test_bool_false(self) -> None:
204 """Check XmlBool false values.""" 207 """Check XmlBool false values."""
205 for value in ("no", "false", "0"): 208 for value in ("no", "false", "0"):
206 node = self._get_node(f'<node a="{value}"/>') 209 node = self._get_node(f'<node a="{value}"/>')
207 self.assertFalse(manifest_xml.XmlBool(node, "a")) 210 assert manifest_xml.XmlBool(node, "a") is False
208 211
209 def test_int_default(self): 212 def test_int_default(self) -> None:
210 """Check XmlInt default handling.""" 213 """Check XmlInt default handling."""
211 node = self._get_node("<node/>") 214 node = self._get_node("<node/>")
212 self.assertIsNone(manifest_xml.XmlInt(node, "a")) 215 assert manifest_xml.XmlInt(node, "a") is None
213 self.assertIsNone(manifest_xml.XmlInt(node, "a", None)) 216 assert manifest_xml.XmlInt(node, "a", None) is None
214 self.assertEqual(123, manifest_xml.XmlInt(node, "a", 123)) 217 assert manifest_xml.XmlInt(node, "a", 123) == 123
215 218
216 node = self._get_node('<node a=""/>') 219 node = self._get_node('<node a=""/>')
217 self.assertIsNone(manifest_xml.XmlInt(node, "a")) 220 assert manifest_xml.XmlInt(node, "a") is None
218 221
219 def test_int_good(self): 222 def test_int_good(self) -> None:
220 """Check XmlInt numeric handling.""" 223 """Check XmlInt numeric handling."""
221 for value in (-1, 0, 1, 50000): 224 for value in (-1, 0, 1, 50000):
222 node = self._get_node(f'<node a="{value}"/>') 225 node = self._get_node(f'<node a="{value}"/>')
223 self.assertEqual(value, manifest_xml.XmlInt(node, "a")) 226 assert manifest_xml.XmlInt(node, "a") == value
224 227
225 def test_int_invalid(self): 228 def test_int_invalid(self) -> None:
226 """Check XmlInt invalid handling.""" 229 """Check XmlInt invalid handling."""
227 with self.assertRaises(error.ManifestParseError): 230 with pytest.raises(error.ManifestParseError):
228 node = self._get_node('<node a="xx"/>') 231 node = self._get_node('<node a="xx"/>')
229 manifest_xml.XmlInt(node, "a") 232 manifest_xml.XmlInt(node, "a")
230 233
231 234
232class XmlManifestTests(ManifestParseTestCase): 235class TestXmlManifest:
233 """Check manifest processing.""" 236 """Check manifest processing."""
234 237
235 def test_empty(self): 238 def test_empty(self, repo_client: RepoClient) -> None:
236 """Parse an 'empty' manifest file.""" 239 """Parse an 'empty' manifest file."""
237 manifest = self.getXmlManifest( 240 manifest = repo_client.get_xml_manifest(
238 '<?xml version="1.0" encoding="UTF-8"?>' "<manifest></manifest>" 241 '<?xml version="1.0" encoding="UTF-8"?>' "<manifest></manifest>"
239 ) 242 )
240 self.assertEqual(manifest.remotes, {}) 243 assert manifest.remotes == {}
241 self.assertEqual(manifest.projects, []) 244 assert manifest.projects == []
242 245
243 def test_link(self): 246 def test_link(self, repo_client: RepoClient) -> None:
244 """Verify Link handling with new names.""" 247 """Verify Link handling with new names."""
245 manifest = manifest_xml.XmlManifest( 248 manifest = repo_client.get_xml_manifest("<manifest></manifest>")
246 str(self.repodir), str(self.manifest_file) 249 (repo_client.manifest_dir / "foo.xml").write_text(
250 "<manifest></manifest>"
247 ) 251 )
248 (self.manifest_dir / "foo.xml").write_text("<manifest></manifest>")
249 manifest.Link("foo.xml") 252 manifest.Link("foo.xml")
250 self.assertIn( 253 assert (
251 '<include name="foo.xml" />', self.manifest_file.read_text() 254 '<include name="foo.xml" />'
255 in repo_client.manifest_file.read_text()
252 ) 256 )
253 257
254 def test_toxml_empty(self): 258 def test_toxml_empty(self, repo_client: RepoClient) -> None:
255 """Verify the ToXml() helper.""" 259 """Verify the ToXml() helper."""
256 manifest = self.getXmlManifest( 260 manifest = repo_client.get_xml_manifest(
257 '<?xml version="1.0" encoding="UTF-8"?>' "<manifest></manifest>" 261 '<?xml version="1.0" encoding="UTF-8"?>' "<manifest></manifest>"
258 ) 262 )
259 self.assertEqual( 263 assert manifest.ToXml().toxml() == '<?xml version="1.0" ?><manifest/>'
260 manifest.ToXml().toxml(), '<?xml version="1.0" ?><manifest/>'
261 )
262 264
263 def test_todict_empty(self): 265 def test_todict_empty(self, repo_client: RepoClient) -> None:
264 """Verify the ToDict() helper.""" 266 """Verify the ToDict() helper."""
265 manifest = self.getXmlManifest( 267 manifest = repo_client.get_xml_manifest(
266 '<?xml version="1.0" encoding="UTF-8"?>' "<manifest></manifest>" 268 '<?xml version="1.0" encoding="UTF-8"?>' "<manifest></manifest>"
267 ) 269 )
268 self.assertEqual(manifest.ToDict(), {}) 270 assert manifest.ToDict() == {}
269 271
270 def test_toxml_omit_local(self): 272 def test_toxml_omit_local(self, repo_client: RepoClient) -> None:
271 """Does not include local_manifests projects when omit_local=True.""" 273 """Does not include local_manifests projects when omit_local=True."""
272 manifest = self.getXmlManifest( 274 manifest = repo_client.get_xml_manifest(
273 '<?xml version="1.0" encoding="UTF-8"?><manifest>' 275 '<?xml version="1.0" encoding="UTF-8"?><manifest>'
274 '<remote name="a" fetch=".."/><default remote="a" revision="r"/>' 276 '<remote name="a" fetch=".."/><default remote="a" revision="r"/>'
275 '<project name="p" groups="local::me"/>' 277 '<project name="p" groups="local::me"/>'
@@ -277,16 +279,16 @@ class XmlManifestTests(ManifestParseTestCase):
277 '<project name="r" groups="keep"/>' 279 '<project name="r" groups="keep"/>'
278 "</manifest>" 280 "</manifest>"
279 ) 281 )
280 self.assertEqual( 282 assert (
281 sort_attributes(manifest.ToXml(omit_local=True).toxml()), 283 sort_attributes(manifest.ToXml(omit_local=True).toxml())
282 '<?xml version="1.0" ?><manifest>' 284 == '<?xml version="1.0" ?><manifest>'
283 '<remote fetch=".." name="a"/><default remote="a" revision="r"/>' 285 '<remote fetch=".." name="a"/><default remote="a" revision="r"/>'
284 '<project name="q"/><project groups="keep" name="r"/></manifest>', 286 '<project name="q"/><project groups="keep" name="r"/></manifest>'
285 ) 287 )
286 288
287 def test_toxml_with_local(self): 289 def test_toxml_with_local(self, repo_client: RepoClient) -> None:
288 """Does include local_manifests projects when omit_local=False.""" 290 """Does include local_manifests projects when omit_local=False."""
289 manifest = self.getXmlManifest( 291 manifest = repo_client.get_xml_manifest(
290 '<?xml version="1.0" encoding="UTF-8"?><manifest>' 292 '<?xml version="1.0" encoding="UTF-8"?><manifest>'
291 '<remote name="a" fetch=".."/><default remote="a" revision="r"/>' 293 '<remote name="a" fetch=".."/><default remote="a" revision="r"/>'
292 '<project name="p" groups="local::me"/>' 294 '<project name="p" groups="local::me"/>'
@@ -294,17 +296,17 @@ class XmlManifestTests(ManifestParseTestCase):
294 '<project name="r" groups="keep"/>' 296 '<project name="r" groups="keep"/>'
295 "</manifest>" 297 "</manifest>"
296 ) 298 )
297 self.assertEqual( 299 assert (
298 sort_attributes(manifest.ToXml(omit_local=False).toxml()), 300 sort_attributes(manifest.ToXml(omit_local=False).toxml())
299 '<?xml version="1.0" ?><manifest>' 301 == '<?xml version="1.0" ?><manifest>'
300 '<remote fetch=".." name="a"/><default remote="a" revision="r"/>' 302 '<remote fetch=".." name="a"/><default remote="a" revision="r"/>'
301 '<project groups="local::me" name="p"/>' 303 '<project groups="local::me" name="p"/>'
302 '<project name="q"/><project groups="keep" name="r"/></manifest>', 304 '<project name="q"/><project groups="keep" name="r"/></manifest>'
303 ) 305 )
304 306
305 def test_repo_hooks(self): 307 def test_repo_hooks(self, repo_client: RepoClient) -> None:
306 """Check repo-hooks settings.""" 308 """Check repo-hooks settings."""
307 manifest = self.getXmlManifest( 309 manifest = repo_client.get_xml_manifest(
308 """ 310 """
309<manifest> 311<manifest>
310 <remote name="test-remote" fetch="http://localhost" /> 312 <remote name="test-remote" fetch="http://localhost" />
@@ -314,14 +316,12 @@ class XmlManifestTests(ManifestParseTestCase):
314</manifest> 316</manifest>
315""" 317"""
316 ) 318 )
317 self.assertEqual(manifest.repo_hooks_project.name, "repohooks") 319 assert manifest.repo_hooks_project.name == "repohooks"
318 self.assertEqual( 320 assert manifest.repo_hooks_project.enabled_repo_hooks == ["a", "b"]
319 manifest.repo_hooks_project.enabled_repo_hooks, ["a", "b"]
320 )
321 321
322 def test_repo_hooks_unordered(self): 322 def test_repo_hooks_unordered(self, repo_client: RepoClient) -> None:
323 """Check repo-hooks settings work even if the project def comes second.""" # noqa: E501 323 """Check repo-hooks settings work when the project comes after."""
324 manifest = self.getXmlManifest( 324 manifest = repo_client.get_xml_manifest(
325 """ 325 """
326<manifest> 326<manifest>
327 <remote name="test-remote" fetch="http://localhost" /> 327 <remote name="test-remote" fetch="http://localhost" />
@@ -331,14 +331,12 @@ class XmlManifestTests(ManifestParseTestCase):
331</manifest> 331</manifest>
332""" 332"""
333 ) 333 )
334 self.assertEqual(manifest.repo_hooks_project.name, "repohooks") 334 assert manifest.repo_hooks_project.name == "repohooks"
335 self.assertEqual( 335 assert manifest.repo_hooks_project.enabled_repo_hooks == ["a", "b"]
336 manifest.repo_hooks_project.enabled_repo_hooks, ["a", "b"]
337 )
338 336
339 def test_unknown_tags(self): 337 def test_unknown_tags(self, repo_client: RepoClient) -> None:
340 """Check superproject settings.""" 338 """Check superproject settings."""
341 manifest = self.getXmlManifest( 339 manifest = repo_client.get_xml_manifest(
342 """ 340 """
343<manifest> 341<manifest>
344 <remote name="test-remote" fetch="http://localhost" /> 342 <remote name="test-remote" fetch="http://localhost" />
@@ -349,20 +347,20 @@ class XmlManifestTests(ManifestParseTestCase):
349</manifest> 347</manifest>
350""" 348"""
351 ) 349 )
352 self.assertEqual(manifest.superproject.name, "superproject") 350 assert manifest.superproject.name == "superproject"
353 self.assertEqual(manifest.superproject.remote.name, "test-remote") 351 assert manifest.superproject.remote.name == "test-remote"
354 self.assertEqual( 352 assert (
355 sort_attributes(manifest.ToXml().toxml()), 353 sort_attributes(manifest.ToXml().toxml())
356 '<?xml version="1.0" ?><manifest>' 354 == '<?xml version="1.0" ?><manifest>'
357 '<remote fetch="http://localhost" name="test-remote"/>' 355 '<remote fetch="http://localhost" name="test-remote"/>'
358 '<default remote="test-remote" revision="refs/heads/main"/>' 356 '<default remote="test-remote" revision="refs/heads/main"/>'
359 '<superproject name="superproject"/>' 357 '<superproject name="superproject"/>'
360 "</manifest>", 358 "</manifest>"
361 ) 359 )
362 360
363 def test_remote_annotations(self): 361 def test_remote_annotations(self, repo_client: RepoClient) -> None:
364 """Check remote settings.""" 362 """Check remote settings."""
365 manifest = self.getXmlManifest( 363 manifest = repo_client.get_xml_manifest(
366 """ 364 """
367<manifest> 365<manifest>
368 <remote name="test-remote" fetch="http://localhost"> 366 <remote name="test-remote" fetch="http://localhost">
@@ -371,24 +369,20 @@ class XmlManifestTests(ManifestParseTestCase):
371</manifest> 369</manifest>
372""" 370"""
373 ) 371 )
374 self.assertEqual( 372 assert manifest.remotes["test-remote"].annotations[0].name == "foo"
375 manifest.remotes["test-remote"].annotations[0].name, "foo" 373 assert manifest.remotes["test-remote"].annotations[0].value == "bar"
376 ) 374 assert (
377 self.assertEqual( 375 sort_attributes(manifest.ToXml().toxml())
378 manifest.remotes["test-remote"].annotations[0].value, "bar" 376 == '<?xml version="1.0" ?><manifest>'
379 )
380 self.assertEqual(
381 sort_attributes(manifest.ToXml().toxml()),
382 '<?xml version="1.0" ?><manifest>'
383 '<remote fetch="http://localhost" name="test-remote">' 377 '<remote fetch="http://localhost" name="test-remote">'
384 '<annotation name="foo" value="bar"/>' 378 '<annotation name="foo" value="bar"/>'
385 "</remote>" 379 "</remote>"
386 "</manifest>", 380 "</manifest>"
387 ) 381 )
388 382
389 def test_parse_with_xml_doctype(self): 383 def test_parse_with_xml_doctype(self, repo_client: RepoClient) -> None:
390 """Check correct manifest parse with DOCTYPE node present.""" 384 """Check correct manifest parse with DOCTYPE node present."""
391 manifest = self.getXmlManifest( 385 manifest = repo_client.get_xml_manifest(
392 """<?xml version="1.0" encoding="UTF-8"?> 386 """<?xml version="1.0" encoding="UTF-8"?>
393<!DOCTYPE manifest []> 387<!DOCTYPE manifest []>
394<manifest> 388<manifest>
@@ -398,42 +392,41 @@ class XmlManifestTests(ManifestParseTestCase):
398</manifest> 392</manifest>
399""" 393"""
400 ) 394 )
401 self.assertEqual(len(manifest.projects), 1) 395 assert len(manifest.projects) == 1
402 self.assertEqual(manifest.projects[0].name, "test-project") 396 assert manifest.projects[0].name == "test-project"
403 397
404 def test_sync_j_max(self): 398 def test_sync_j_max(self, repo_client: RepoClient) -> None:
405 """Check sync-j-max handling.""" 399 """Check sync-j-max handling."""
406 # Check valid value. 400 # Check valid value.
407 manifest = self.getXmlManifest( 401 manifest = repo_client.get_xml_manifest(
408 '<manifest><default sync-j-max="5" /></manifest>' 402 '<manifest><default sync-j-max="5" /></manifest>'
409 ) 403 )
410 self.assertEqual(manifest.default.sync_j_max, 5) 404 assert manifest.default.sync_j_max == 5
411 self.assertEqual( 405 assert (
412 manifest.ToXml().toxml(), 406 manifest.ToXml().toxml() == '<?xml version="1.0" ?>'
413 '<?xml version="1.0" ?>' 407 '<manifest><default sync-j-max="5"/></manifest>'
414 '<manifest><default sync-j-max="5"/></manifest>',
415 ) 408 )
416 409
417 # Check invalid values. 410 # Check invalid values.
418 with self.assertRaises(error.ManifestParseError): 411 with pytest.raises(error.ManifestParseError):
419 manifest = self.getXmlManifest( 412 manifest = repo_client.get_xml_manifest(
420 '<manifest><default sync-j-max="0" /></manifest>' 413 '<manifest><default sync-j-max="0" /></manifest>'
421 ) 414 )
422 manifest.ToXml() 415 manifest.ToXml()
423 416
424 with self.assertRaises(error.ManifestParseError): 417 with pytest.raises(error.ManifestParseError):
425 manifest = self.getXmlManifest( 418 manifest = repo_client.get_xml_manifest(
426 '<manifest><default sync-j-max="-1" /></manifest>' 419 '<manifest><default sync-j-max="-1" /></manifest>'
427 ) 420 )
428 manifest.ToXml() 421 manifest.ToXml()
429 422
430 423
431class IncludeElementTests(ManifestParseTestCase): 424class TestIncludeElement:
432 """Tests for <include>.""" 425 """Tests for <include>."""
433 426
434 def test_revision_default(self): 427 def test_revision_default(self, repo_client: RepoClient) -> None:
435 """Check handling of revision attribute.""" 428 """Check handling of revision attribute."""
436 root_m = self.manifest_dir / "root.xml" 429 root_m = repo_client.manifest_dir / "root.xml"
437 root_m.write_text( 430 root_m.write_text(
438 """ 431 """
439<manifest> 432<manifest>
@@ -445,7 +438,7 @@ class IncludeElementTests(ManifestParseTestCase):
445</manifest> 438</manifest>
446""" 439"""
447 ) 440 )
448 (self.manifest_dir / "stable.xml").write_text( 441 (repo_client.manifest_dir / "stable.xml").write_text(
449 """ 442 """
450<manifest> 443<manifest>
451 <include name="man1.xml" /> 444 <include name="man1.xml" />
@@ -455,7 +448,7 @@ class IncludeElementTests(ManifestParseTestCase):
455</manifest> 448</manifest>
456""" 449"""
457 ) 450 )
458 (self.manifest_dir / "man1.xml").write_text( 451 (repo_client.manifest_dir / "man1.xml").write_text(
459 """ 452 """
460<manifest> 453<manifest>
461 <project name="man1-name1" /> 454 <project name="man1-name1" />
@@ -463,7 +456,7 @@ class IncludeElementTests(ManifestParseTestCase):
463</manifest> 456</manifest>
464""" 457"""
465 ) 458 )
466 (self.manifest_dir / "man2.xml").write_text( 459 (repo_client.manifest_dir / "man2.xml").write_text(
467 """ 460 """
468<manifest> 461<manifest>
469 <project name="man2-name1" /> 462 <project name="man2-name1" />
@@ -471,31 +464,34 @@ class IncludeElementTests(ManifestParseTestCase):
471</manifest> 464</manifest>
472""" 465"""
473 ) 466 )
474 include_m = manifest_xml.XmlManifest(str(self.repodir), str(root_m)) 467 include_m = manifest_xml.XmlManifest(
468 str(repo_client.repodir), str(root_m)
469 )
475 for proj in include_m.projects: 470 for proj in include_m.projects:
476 if proj.name == "root-name1": 471 if proj.name == "root-name1":
477 # Check include revision not set on root level proj. 472 # Check include revision not set on root level proj.
478 self.assertNotEqual("stable-branch", proj.revisionExpr) 473 assert proj.revisionExpr != "stable-branch"
479 if proj.name == "root-name2": 474 if proj.name == "root-name2":
480 # Check root proj revision not removed. 475 # Check root proj revision not removed.
481 self.assertEqual("refs/heads/main", proj.revisionExpr) 476 assert proj.revisionExpr == "refs/heads/main"
482 if proj.name == "stable-name1": 477 if proj.name == "stable-name1":
483 # Check stable proj has inherited revision include node. 478 # Check stable proj has inherited revision include node.
484 self.assertEqual("stable-branch", proj.revisionExpr) 479 assert proj.revisionExpr == "stable-branch"
485 if proj.name == "stable-name2": 480 if proj.name == "stable-name2":
486 # Check stable proj revision can override include node. 481 # Check stable proj revision can override include node.
487 self.assertEqual("stable-branch2", proj.revisionExpr) 482 assert proj.revisionExpr == "stable-branch2"
488 if proj.name == "man1-name1": 483 if proj.name == "man1-name1":
489 self.assertEqual("stable-branch", proj.revisionExpr) 484 assert proj.revisionExpr == "stable-branch"
490 if proj.name == "man1-name2": 485 if proj.name == "man1-name2":
491 self.assertEqual("stable-branch3", proj.revisionExpr) 486 assert proj.revisionExpr == "stable-branch3"
492 if proj.name == "man2-name1": 487 if proj.name == "man2-name1":
493 self.assertEqual("stable-branch2", proj.revisionExpr) 488 assert proj.revisionExpr == "stable-branch2"
494 if proj.name == "man2-name2": 489 if proj.name == "man2-name2":
495 self.assertEqual("stable-branch3", proj.revisionExpr) 490 assert proj.revisionExpr == "stable-branch3"
496 491
497 def test_group_levels(self): 492 def test_group_levels(self, repo_client: RepoClient) -> None:
498 root_m = self.manifest_dir / "root.xml" 493 """Check handling of nested include groups."""
494 root_m = repo_client.manifest_dir / "root.xml"
499 root_m.write_text( 495 root_m.write_text(
500 """ 496 """
501<manifest> 497<manifest>
@@ -507,7 +503,7 @@ class IncludeElementTests(ManifestParseTestCase):
507</manifest> 503</manifest>
508""" 504"""
509 ) 505 )
510 (self.manifest_dir / "level1.xml").write_text( 506 (repo_client.manifest_dir / "level1.xml").write_text(
511 """ 507 """
512<manifest> 508<manifest>
513 <include name="level2.xml" groups="level2-group" /> 509 <include name="level2.xml" groups="level2-group" />
@@ -515,33 +511,38 @@ class IncludeElementTests(ManifestParseTestCase):
515</manifest> 511</manifest>
516""" 512"""
517 ) 513 )
518 (self.manifest_dir / "level2.xml").write_text( 514 (repo_client.manifest_dir / "level2.xml").write_text(
519 """ 515 """
520<manifest> 516<manifest>
521 <project name="level2-name1" path="level2-path1" groups="l2g1,l2g2" /> 517 <project name="level2-name1" path="level2-path1" groups="l2g1,l2g2" />
522</manifest> 518</manifest>
523""" 519"""
524 ) 520 )
525 include_m = manifest_xml.XmlManifest(str(self.repodir), str(root_m)) 521 include_m = manifest_xml.XmlManifest(
522 str(repo_client.repodir), str(root_m)
523 )
526 for proj in include_m.projects: 524 for proj in include_m.projects:
527 if proj.name == "root-name1": 525 if proj.name == "root-name1":
528 # Check include group not set on root level proj. 526 # Check include group not set on root level proj.
529 self.assertNotIn("level1-group", proj.groups) 527 assert "level1-group" not in proj.groups
530 if proj.name == "root-name2": 528 if proj.name == "root-name2":
531 # Check root proj group not removed. 529 # Check root proj group not removed.
532 self.assertIn("r2g1", proj.groups) 530 assert "r2g1" in proj.groups
533 if proj.name == "level1-name1": 531 if proj.name == "level1-name1":
534 # Check level1 proj has inherited group level 1. 532 # Check level1 proj has inherited group level 1.
535 self.assertIn("level1-group", proj.groups) 533 assert "level1-group" in proj.groups
536 if proj.name == "level2-name1": 534 if proj.name == "level2-name1":
537 # Check level2 proj has inherited group levels 1 and 2. 535 # Check level2 proj has inherited group levels 1 and 2.
538 self.assertIn("level1-group", proj.groups) 536 assert "level1-group" in proj.groups
539 self.assertIn("level2-group", proj.groups) 537 assert "level2-group" in proj.groups
540 # Check level2 proj group not removed. 538 # Check level2 proj group not removed.
541 self.assertIn("l2g1", proj.groups) 539 assert "l2g1" in proj.groups
542 540
543 def test_group_levels_with_extend_project(self): 541 def test_group_levels_with_extend_project(
544 root_m = self.manifest_dir / "root.xml" 542 self, repo_client: RepoClient
543 ) -> None:
544 """Check inheritance of groups via extend-project."""
545 root_m = repo_client.manifest_dir / "root.xml"
545 root_m.write_text( 546 root_m.write_text(
546 """ 547 """
547<manifest> 548<manifest>
@@ -552,32 +553,36 @@ class IncludeElementTests(ManifestParseTestCase):
552</manifest> 553</manifest>
553""" 554"""
554 ) 555 )
555 (self.manifest_dir / "man1.xml").write_text( 556 (repo_client.manifest_dir / "man1.xml").write_text(
556 """ 557 """
557<manifest> 558<manifest>
558 <project name="project1" path="project1" /> 559 <project name="project1" path="project1" />
559</manifest> 560</manifest>
560""" 561"""
561 ) 562 )
562 (self.manifest_dir / "man2.xml").write_text( 563 (repo_client.manifest_dir / "man2.xml").write_text(
563 """ 564 """
564<manifest> 565<manifest>
565 <extend-project name="project1" groups="eg1" /> 566 <extend-project name="project1" groups="eg1" />
566</manifest> 567</manifest>
567""" 568"""
568 ) 569 )
569 include_m = manifest_xml.XmlManifest(str(self.repodir), str(root_m)) 570 include_m = manifest_xml.XmlManifest(
571 str(repo_client.repodir), str(root_m)
572 )
570 proj = include_m.projects[0] 573 proj = include_m.projects[0]
571 # Check project has inherited group via project element. 574 # Check project has inherited group via project element.
572 self.assertIn("top-group1", proj.groups) 575 assert "top-group1" in proj.groups
573 # Check project has inherited group via extend-project element. 576 # Check project has inherited group via extend-project element.
574 self.assertIn("top-group2", proj.groups) 577 assert "top-group2" in proj.groups
575 # Check project has set group via extend-project element. 578 # Check project has set group via extend-project element.
576 self.assertIn("eg1", proj.groups) 579 assert "eg1" in proj.groups
577 580
578 def test_extend_project_does_not_inherit_local_groups(self): 581 def test_extend_project_does_not_inherit_local_groups(
582 self, repo_client: RepoClient
583 ) -> None:
579 """Check that extend-project does not inherit local groups.""" 584 """Check that extend-project does not inherit local groups."""
580 root_m = self.manifest_dir / "root.xml" 585 root_m = repo_client.manifest_dir / "root.xml"
581 root_m.write_text( 586 root_m.write_text(
582 """ 587 """
583<manifest> 588<manifest>
@@ -588,26 +593,28 @@ class IncludeElementTests(ManifestParseTestCase):
588</manifest> 593</manifest>
589""" 594"""
590 ) 595 )
591 (self.manifest_dir / "man1.xml").write_text( 596 (repo_client.manifest_dir / "man1.xml").write_text(
592 """ 597 """
593<manifest> 598<manifest>
594 <extend-project name="project1" groups="g3" /> 599 <extend-project name="project1" groups="g3" />
595</manifest> 600</manifest>
596""" 601"""
597 ) 602 )
598 include_m = manifest_xml.XmlManifest(str(self.repodir), str(root_m)) 603 include_m = manifest_xml.XmlManifest(
604 str(repo_client.repodir), str(root_m)
605 )
599 proj = include_m.projects[0] 606 proj = include_m.projects[0]
600 607
601 self.assertIn("g1", proj.groups) 608 assert "g1" in proj.groups
602 self.assertNotIn("local:g2", proj.groups) 609 assert "local:g2" not in proj.groups
603 self.assertIn("g3", proj.groups) 610 assert "g3" in proj.groups
604 611
605 def test_allow_bad_name_from_user(self): 612 def test_allow_bad_name_from_user(self, repo_client: RepoClient) -> None:
606 """Check handling of bad name attribute from the user's input.""" 613 """Check handling of bad name attribute from the user's input."""
607 614
608 def parse(name): 615 def parse(name: str) -> None:
609 name = self.encodeXmlAttr(name) 616 name = repo_client.encode_xml_attr(name)
610 manifest = self.getXmlManifest( 617 manifest = repo_client.get_xml_manifest(
611 f""" 618 f"""
612<manifest> 619<manifest>
613 <remote name="default-remote" fetch="http://localhost" /> 620 <remote name="default-remote" fetch="http://localhost" />
@@ -620,26 +627,26 @@ class IncludeElementTests(ManifestParseTestCase):
620 manifest.ToXml() 627 manifest.ToXml()
621 628
622 # Setup target of the include. 629 # Setup target of the include.
623 target = self.tempdir / "target.xml" 630 target = repo_client.topdir / "target.xml"
624 target.write_text("<manifest></manifest>") 631 target.write_text("<manifest></manifest>")
625 632
626 # Include with absolute path. 633 # Include with absolute path.
627 parse(os.path.abspath(target)) 634 parse(str(target.absolute()))
628 635
629 # Include with relative path. 636 # Include with relative path.
630 parse(os.path.relpath(target, self.manifest_dir)) 637 parse(os.path.relpath(str(target), str(repo_client.manifest_dir)))
631 638
632 def test_bad_name_checks(self): 639 def test_bad_name_checks(self, repo_client: RepoClient) -> None:
633 """Check handling of bad name attribute.""" 640 """Check handling of bad name attribute."""
634 641
635 def parse(name): 642 def parse(name: str) -> None:
636 name = self.encodeXmlAttr(name) 643 name = repo_client.encode_xml_attr(name)
637 # Setup target of the include. 644 # Setup target of the include.
638 (self.manifest_dir / "target.xml").write_text( 645 (repo_client.manifest_dir / "target.xml").write_text(
639 f'<manifest><include name="{name}"/></manifest>' 646 f'<manifest><include name="{name}"/></manifest>'
640 ) 647 )
641 648
642 manifest = self.getXmlManifest( 649 manifest = repo_client.get_xml_manifest(
643 """ 650 """
644<manifest> 651<manifest>
645 <remote name="default-remote" fetch="http://localhost" /> 652 <remote name="default-remote" fetch="http://localhost" />
@@ -652,23 +659,23 @@ class IncludeElementTests(ManifestParseTestCase):
652 manifest.ToXml() 659 manifest.ToXml()
653 660
654 # Handle empty name explicitly because a different codepath rejects it. 661 # Handle empty name explicitly because a different codepath rejects it.
655 with self.assertRaises(error.ManifestParseError): 662 with pytest.raises(error.ManifestParseError):
656 parse("") 663 parse("")
657 664
658 for path in INVALID_FS_PATHS: 665 for path in INVALID_FS_PATHS:
659 if not path: 666 if not path:
660 continue 667 continue
661 668
662 with self.assertRaises(error.ManifestInvalidPathError): 669 with pytest.raises(error.ManifestInvalidPathError):
663 parse(path) 670 parse(path)
664 671
665 672
666class ProjectElementTests(ManifestParseTestCase): 673class TestProjectElement:
667 """Tests for <project>.""" 674 """Tests for <project>."""
668 675
669 def test_group(self): 676 def test_group(self, repo_client: RepoClient) -> None:
670 """Check project group settings.""" 677 """Check project group settings."""
671 manifest = self.getXmlManifest( 678 manifest = repo_client.get_xml_manifest(
672 """ 679 """
673<manifest> 680<manifest>
674 <remote name="test-remote" fetch="http://localhost" /> 681 <remote name="test-remote" fetch="http://localhost" />
@@ -678,28 +685,33 @@ class ProjectElementTests(ManifestParseTestCase):
678</manifest> 685</manifest>
679""" 686"""
680 ) 687 )
681 self.assertEqual(len(manifest.projects), 2) 688 assert len(manifest.projects) == 2
682 # Ordering isn't guaranteed. 689 # Ordering isn't guaranteed.
683 result = { 690 result = {
684 manifest.projects[0].name: manifest.projects[0].groups, 691 manifest.projects[0].name: manifest.projects[0].groups,
685 manifest.projects[1].name: manifest.projects[1].groups, 692 manifest.projects[1].name: manifest.projects[1].groups,
686 } 693 }
687 self.assertEqual( 694 assert result["test-name"] == {
688 result["test-name"], {"name:test-name", "all", "path:test-path"} 695 "name:test-name",
689 ) 696 "all",
690 self.assertEqual( 697 "path:test-path",
691 result["extras"], 698 }
692 {"g1", "g2", "name:extras", "all", "path:path"}, 699 assert result["extras"] == {
693 ) 700 "g1",
701 "g2",
702 "name:extras",
703 "all",
704 "path:path",
705 }
694 groupstr = "default,platform-" + platform.system().lower() 706 groupstr = "default,platform-" + platform.system().lower()
695 self.assertEqual(groupstr, manifest.GetManifestGroupsStr()) 707 assert manifest.GetManifestGroupsStr() == groupstr
696 groupstr = "g1,g2,g1" 708 groupstr = "g1,g2,g1"
697 manifest.manifestProject.config.SetString("manifest.groups", groupstr) 709 manifest.manifestProject.config.SetString("manifest.groups", groupstr)
698 self.assertEqual(groupstr, manifest.GetManifestGroupsStr()) 710 assert manifest.GetManifestGroupsStr() == groupstr
699 711
700 def test_set_revision_id(self): 712 def test_set_revision_id(self, repo_client: RepoClient) -> None:
701 """Check setting of project's revisionId.""" 713 """Check setting of project's revisionId."""
702 manifest = self.getXmlManifest( 714 manifest = repo_client.get_xml_manifest(
703 """ 715 """
704<manifest> 716<manifest>
705 <remote name="default-remote" fetch="http://localhost" /> 717 <remote name="default-remote" fetch="http://localhost" />
@@ -708,25 +720,48 @@ class ProjectElementTests(ManifestParseTestCase):
708</manifest> 720</manifest>
709""" 721"""
710 ) 722 )
711 self.assertEqual(len(manifest.projects), 1) 723 assert len(manifest.projects) == 1
712 project = manifest.projects[0] 724 project = manifest.projects[0]
713 project.SetRevisionId("ABCDEF") 725 project.SetRevisionId("ABCDEF")
714 self.assertEqual( 726 assert (
715 sort_attributes(manifest.ToXml().toxml()), 727 sort_attributes(manifest.ToXml().toxml())
716 '<?xml version="1.0" ?><manifest>' 728 == '<?xml version="1.0" ?><manifest>'
717 '<remote fetch="http://localhost" name="default-remote"/>' 729 '<remote fetch="http://localhost" name="default-remote"/>'
718 '<default remote="default-remote" revision="refs/heads/main"/>' 730 '<default remote="default-remote" revision="refs/heads/main"/>'
719 '<project name="test-name" revision="ABCDEF" upstream="refs/heads/main"/>' # noqa: E501 731 '<project name="test-name" revision="ABCDEF" upstream="refs/heads/main"/>' # noqa: E501
720 "</manifest>", 732 "</manifest>"
721 ) 733 )
722 734
723 def test_trailing_slash(self): 735 def test_sync_strategy(self, repo_client: RepoClient) -> None:
736 """Check setting of project's sync_strategy."""
737 manifest = repo_client.get_xml_manifest(
738 """
739<manifest>
740 <remote name="default-remote" fetch="http://localhost" />
741 <default remote="default-remote" revision="refs/heads/main" />
742 <project name="test-name" sync-strategy="stateless"/>
743</manifest>
744"""
745 )
746 assert len(manifest.projects) == 1
747 project = manifest.projects[0]
748 assert project.sync_strategy == "stateless"
749 assert (
750 sort_attributes(manifest.ToXml().toxml())
751 == '<?xml version="1.0" ?><manifest>'
752 '<remote fetch="http://localhost" name="default-remote"/>'
753 '<default remote="default-remote" revision="refs/heads/main"/>'
754 '<project name="test-name" sync-strategy="stateless"/>'
755 "</manifest>"
756 )
757
758 def test_trailing_slash(self, repo_client: RepoClient) -> None:
724 """Check handling of trailing slashes in attributes.""" 759 """Check handling of trailing slashes in attributes."""
725 760
726 def parse(name, path): 761 def parse(name: str, path: str) -> manifest_xml.XmlManifest:
727 name = self.encodeXmlAttr(name) 762 name = repo_client.encode_xml_attr(name)
728 path = self.encodeXmlAttr(path) 763 path = repo_client.encode_xml_attr(path)
729 return self.getXmlManifest( 764 return repo_client.get_xml_manifest(
730 f""" 765 f"""
731<manifest> 766<manifest>
732 <remote name="default-remote" fetch="http://localhost" /> 767 <remote name="default-remote" fetch="http://localhost" />
@@ -737,48 +772,36 @@ class ProjectElementTests(ManifestParseTestCase):
737 ) 772 )
738 773
739 manifest = parse("a/path/", "foo") 774 manifest = parse("a/path/", "foo")
740 self.assertEqual( 775 assert os.path.normpath(manifest.projects[0].gitdir) == os.path.join(
741 os.path.normpath(manifest.projects[0].gitdir), 776 str(repo_client.topdir), ".repo", "projects", "foo.git"
742 os.path.join(self.tempdir, ".repo", "projects", "foo.git"),
743 ) 777 )
744 self.assertEqual( 778 assert os.path.normpath(manifest.projects[0].objdir) == os.path.join(
745 os.path.normpath(manifest.projects[0].objdir), 779 str(repo_client.topdir), ".repo", "project-objects", "a", "path.git"
746 os.path.join(
747 self.tempdir, ".repo", "project-objects", "a", "path.git"
748 ),
749 ) 780 )
750 781
751 manifest = parse("a/path", "foo/") 782 manifest = parse("a/path", "foo/")
752 self.assertEqual( 783 assert os.path.normpath(manifest.projects[0].gitdir) == os.path.join(
753 os.path.normpath(manifest.projects[0].gitdir), 784 str(repo_client.topdir), ".repo", "projects", "foo.git"
754 os.path.join(self.tempdir, ".repo", "projects", "foo.git"),
755 ) 785 )
756 self.assertEqual( 786 assert os.path.normpath(manifest.projects[0].objdir) == os.path.join(
757 os.path.normpath(manifest.projects[0].objdir), 787 str(repo_client.topdir), ".repo", "project-objects", "a", "path.git"
758 os.path.join(
759 self.tempdir, ".repo", "project-objects", "a", "path.git"
760 ),
761 ) 788 )
762 789
763 manifest = parse("a/path", "foo//////") 790 manifest = parse("a/path", "foo//////")
764 self.assertEqual( 791 assert os.path.normpath(manifest.projects[0].gitdir) == os.path.join(
765 os.path.normpath(manifest.projects[0].gitdir), 792 str(repo_client.topdir), ".repo", "projects", "foo.git"
766 os.path.join(self.tempdir, ".repo", "projects", "foo.git"),
767 ) 793 )
768 self.assertEqual( 794 assert os.path.normpath(manifest.projects[0].objdir) == os.path.join(
769 os.path.normpath(manifest.projects[0].objdir), 795 str(repo_client.topdir), ".repo", "project-objects", "a", "path.git"
770 os.path.join(
771 self.tempdir, ".repo", "project-objects", "a", "path.git"
772 ),
773 ) 796 )
774 797
775 def test_toplevel_path(self): 798 def test_toplevel_path(self, repo_client: RepoClient) -> None:
776 """Check handling of path=. specially.""" 799 """Check handling of path=. specially."""
777 800
778 def parse(name, path): 801 def parse(name: str, path: str) -> manifest_xml.XmlManifest:
779 name = self.encodeXmlAttr(name) 802 name = repo_client.encode_xml_attr(name)
780 path = self.encodeXmlAttr(path) 803 path = repo_client.encode_xml_attr(path)
781 return self.getXmlManifest( 804 return repo_client.get_xml_manifest(
782 f""" 805 f"""
783<manifest> 806<manifest>
784 <remote name="default-remote" fetch="http://localhost" /> 807 <remote name="default-remote" fetch="http://localhost" />
@@ -790,18 +813,19 @@ class ProjectElementTests(ManifestParseTestCase):
790 813
791 for path in (".", "./", ".//", ".///"): 814 for path in (".", "./", ".//", ".///"):
792 manifest = parse("server/path", path) 815 manifest = parse("server/path", path)
793 self.assertEqual( 816 assert os.path.normpath(
794 os.path.normpath(manifest.projects[0].gitdir), 817 manifest.projects[0].gitdir
795 os.path.join(self.tempdir, ".repo", "projects", "..git"), 818 ) == os.path.join(
819 str(repo_client.topdir), ".repo", "projects", "..git"
796 ) 820 )
797 821
798 def test_bad_path_name_checks(self): 822 def test_bad_path_name_checks(self, repo_client: RepoClient) -> None:
799 """Check handling of bad path & name attributes.""" 823 """Check handling of bad path & name attributes."""
800 824
801 def parse(name, path): 825 def parse(name: str, path: str) -> None:
802 name = self.encodeXmlAttr(name) 826 name = repo_client.encode_xml_attr(name)
803 path = self.encodeXmlAttr(path) 827 path = repo_client.encode_xml_attr(path)
804 manifest = self.getXmlManifest( 828 manifest = repo_client.get_xml_manifest(
805 f""" 829 f"""
806<manifest> 830<manifest>
807 <remote name="default-remote" fetch="http://localhost" /> 831 <remote name="default-remote" fetch="http://localhost" />
@@ -818,28 +842,28 @@ class ProjectElementTests(ManifestParseTestCase):
818 842
819 # Handle empty name explicitly because a different codepath rejects it. 843 # Handle empty name explicitly because a different codepath rejects it.
820 # Empty path is OK because it defaults to the name field. 844 # Empty path is OK because it defaults to the name field.
821 with self.assertRaises(error.ManifestParseError): 845 with pytest.raises(error.ManifestParseError):
822 parse("", "ok") 846 parse("", "ok")
823 847
824 for path in INVALID_FS_PATHS: 848 for path in INVALID_FS_PATHS:
825 if not path or path.endswith("/") or path.endswith(os.path.sep): 849 if not path or path.endswith("/") or path.endswith(os.path.sep):
826 continue 850 continue
827 851
828 with self.assertRaises(error.ManifestInvalidPathError): 852 with pytest.raises(error.ManifestInvalidPathError):
829 parse(path, "ok") 853 parse(path, "ok")
830 854
831 # We have a dedicated test for path=".". 855 # We have a dedicated test for path=".".
832 if path not in {"."}: 856 if path not in {"."}:
833 with self.assertRaises(error.ManifestInvalidPathError): 857 with pytest.raises(error.ManifestInvalidPathError):
834 parse("ok", path) 858 parse("ok", path)
835 859
836 860
837class SuperProjectElementTests(ManifestParseTestCase): 861class TestSuperProjectElement:
838 """Tests for <superproject>.""" 862 """Tests for <superproject>."""
839 863
840 def test_superproject(self): 864 def test_superproject(self, repo_client: RepoClient) -> None:
841 """Check superproject settings.""" 865 """Check superproject settings."""
842 manifest = self.getXmlManifest( 866 manifest = repo_client.get_xml_manifest(
843 """ 867 """
844<manifest> 868<manifest>
845 <remote name="test-remote" fetch="http://localhost" /> 869 <remote name="test-remote" fetch="http://localhost" />
@@ -848,25 +872,24 @@ class SuperProjectElementTests(ManifestParseTestCase):
848</manifest> 872</manifest>
849""" 873"""
850 ) 874 )
851 self.assertEqual(manifest.superproject.name, "superproject") 875 assert manifest.superproject.name == "superproject"
852 self.assertEqual(manifest.superproject.remote.name, "test-remote") 876 assert manifest.superproject.remote.name == "test-remote"
853 self.assertEqual( 877 assert (
854 manifest.superproject.remote.url, "http://localhost/superproject" 878 manifest.superproject.remote.url == "http://localhost/superproject"
855 ) 879 )
856 self.assertEqual(manifest.superproject.revision, "refs/heads/main") 880 assert manifest.superproject.revision == "refs/heads/main"
857 self.assertEqual( 881 assert (
858 sort_attributes(manifest.ToXml().toxml()), 882 sort_attributes(manifest.ToXml().toxml())
859 '<?xml version="1.0" ?><manifest>' 883 == '<?xml version="1.0" ?><manifest>'
860 '<remote fetch="http://localhost" name="test-remote"/>' 884 '<remote fetch="http://localhost" name="test-remote"/>'
861 '<default remote="test-remote" revision="refs/heads/main"/>' 885 '<default remote="test-remote" revision="refs/heads/main"/>'
862 '<superproject name="superproject"/>' 886 '<superproject name="superproject"/>'
863 "</manifest>", 887 "</manifest>"
864 ) 888 )
865 889
866 def test_superproject_revision(self): 890 def test_superproject_revision(self, repo_client: RepoClient) -> None:
867 """Check superproject settings with a different revision attribute""" 891 """Check superproject settings with a different revision attribute"""
868 self.maxDiff = None 892 manifest = repo_client.get_xml_manifest(
869 manifest = self.getXmlManifest(
870 """ 893 """
871<manifest> 894<manifest>
872 <remote name="test-remote" fetch="http://localhost" /> 895 <remote name="test-remote" fetch="http://localhost" />
@@ -875,25 +898,26 @@ class SuperProjectElementTests(ManifestParseTestCase):
875</manifest> 898</manifest>
876""" 899"""
877 ) 900 )
878 self.assertEqual(manifest.superproject.name, "superproject") 901 assert manifest.superproject.name == "superproject"
879 self.assertEqual(manifest.superproject.remote.name, "test-remote") 902 assert manifest.superproject.remote.name == "test-remote"
880 self.assertEqual( 903 assert (
881 manifest.superproject.remote.url, "http://localhost/superproject" 904 manifest.superproject.remote.url == "http://localhost/superproject"
882 ) 905 )
883 self.assertEqual(manifest.superproject.revision, "refs/heads/stable") 906 assert manifest.superproject.revision == "refs/heads/stable"
884 self.assertEqual( 907 assert (
885 sort_attributes(manifest.ToXml().toxml()), 908 sort_attributes(manifest.ToXml().toxml())
886 '<?xml version="1.0" ?><manifest>' 909 == '<?xml version="1.0" ?><manifest>'
887 '<remote fetch="http://localhost" name="test-remote"/>' 910 '<remote fetch="http://localhost" name="test-remote"/>'
888 '<default remote="test-remote" revision="refs/heads/main"/>' 911 '<default remote="test-remote" revision="refs/heads/main"/>'
889 '<superproject name="superproject" revision="refs/heads/stable"/>' 912 '<superproject name="superproject" revision="refs/heads/stable"/>'
890 "</manifest>", 913 "</manifest>"
891 ) 914 )
892 915
893 def test_superproject_revision_default_negative(self): 916 def test_superproject_revision_default_negative(
917 self, repo_client: RepoClient
918 ) -> None:
894 """Check superproject settings with a same revision attribute""" 919 """Check superproject settings with a same revision attribute"""
895 self.maxDiff = None 920 manifest = repo_client.get_xml_manifest(
896 manifest = self.getXmlManifest(
897 """ 921 """
898<manifest> 922<manifest>
899 <remote name="test-remote" fetch="http://localhost" /> 923 <remote name="test-remote" fetch="http://localhost" />
@@ -902,51 +926,53 @@ class SuperProjectElementTests(ManifestParseTestCase):
902</manifest> 926</manifest>
903""" 927"""
904 ) 928 )
905 self.assertEqual(manifest.superproject.name, "superproject") 929 assert manifest.superproject.name == "superproject"
906 self.assertEqual(manifest.superproject.remote.name, "test-remote") 930 assert manifest.superproject.remote.name == "test-remote"
907 self.assertEqual( 931 assert (
908 manifest.superproject.remote.url, "http://localhost/superproject" 932 manifest.superproject.remote.url == "http://localhost/superproject"
909 ) 933 )
910 self.assertEqual(manifest.superproject.revision, "refs/heads/stable") 934 assert manifest.superproject.revision == "refs/heads/stable"
911 self.assertEqual( 935 assert (
912 sort_attributes(manifest.ToXml().toxml()), 936 sort_attributes(manifest.ToXml().toxml())
913 '<?xml version="1.0" ?><manifest>' 937 == '<?xml version="1.0" ?><manifest>'
914 '<remote fetch="http://localhost" name="test-remote"/>' 938 '<remote fetch="http://localhost" name="test-remote"/>'
915 '<default remote="test-remote" revision="refs/heads/stable"/>' 939 '<default remote="test-remote" revision="refs/heads/stable"/>'
916 '<superproject name="superproject"/>' 940 '<superproject name="superproject"/>'
917 "</manifest>", 941 "</manifest>"
918 ) 942 )
919 943
920 def test_superproject_revision_remote(self): 944 def test_superproject_revision_remote(
945 self, repo_client: RepoClient
946 ) -> None:
921 """Check superproject settings with a same revision attribute""" 947 """Check superproject settings with a same revision attribute"""
922 self.maxDiff = None 948 manifest = repo_client.get_xml_manifest(
923 manifest = self.getXmlManifest(
924 """ 949 """
925<manifest> 950<manifest>
926 <remote name="test-remote" fetch="http://localhost" revision="refs/heads/main" /> 951 <remote name="test-remote" fetch="http://localhost"
952 revision="refs/heads/main" />
927 <default remote="test-remote" /> 953 <default remote="test-remote" />
928 <superproject name="superproject" revision="refs/heads/stable" /> 954 <superproject name="superproject" revision="refs/heads/stable" />
929</manifest> 955</manifest>
930""" # noqa: E501 956"""
931 ) 957 )
932 self.assertEqual(manifest.superproject.name, "superproject") 958 assert manifest.superproject.name == "superproject"
933 self.assertEqual(manifest.superproject.remote.name, "test-remote") 959 assert manifest.superproject.remote.name == "test-remote"
934 self.assertEqual( 960 assert (
935 manifest.superproject.remote.url, "http://localhost/superproject" 961 manifest.superproject.remote.url == "http://localhost/superproject"
936 ) 962 )
937 self.assertEqual(manifest.superproject.revision, "refs/heads/stable") 963 assert manifest.superproject.revision == "refs/heads/stable"
938 self.assertEqual( 964 assert (
939 sort_attributes(manifest.ToXml().toxml()), 965 sort_attributes(manifest.ToXml().toxml())
940 '<?xml version="1.0" ?><manifest>' 966 == '<?xml version="1.0" ?><manifest>'
941 '<remote fetch="http://localhost" name="test-remote" revision="refs/heads/main"/>' # noqa: E501 967 '<remote fetch="http://localhost" name="test-remote" revision="refs/heads/main"/>' # noqa: E501
942 '<default remote="test-remote"/>' 968 '<default remote="test-remote"/>'
943 '<superproject name="superproject" revision="refs/heads/stable"/>' 969 '<superproject name="superproject" revision="refs/heads/stable"/>'
944 "</manifest>", 970 "</manifest>"
945 ) 971 )
946 972
947 def test_remote(self): 973 def test_remote(self, repo_client: RepoClient) -> None:
948 """Check superproject settings with a remote.""" 974 """Check superproject settings with a remote."""
949 manifest = self.getXmlManifest( 975 manifest = repo_client.get_xml_manifest(
950 """ 976 """
951<manifest> 977<manifest>
952 <remote name="default-remote" fetch="http://localhost" /> 978 <remote name="default-remote" fetch="http://localhost" />
@@ -956,28 +982,26 @@ class SuperProjectElementTests(ManifestParseTestCase):
956</manifest> 982</manifest>
957""" 983"""
958 ) 984 )
959 self.assertEqual(manifest.superproject.name, "platform/superproject") 985 assert manifest.superproject.name == "platform/superproject"
960 self.assertEqual( 986 assert manifest.superproject.remote.name == "superproject-remote"
961 manifest.superproject.remote.name, "superproject-remote" 987 assert (
962 ) 988 manifest.superproject.remote.url
963 self.assertEqual( 989 == "http://localhost/platform/superproject"
964 manifest.superproject.remote.url,
965 "http://localhost/platform/superproject",
966 ) 990 )
967 self.assertEqual(manifest.superproject.revision, "refs/heads/main") 991 assert manifest.superproject.revision == "refs/heads/main"
968 self.assertEqual( 992 assert (
969 sort_attributes(manifest.ToXml().toxml()), 993 sort_attributes(manifest.ToXml().toxml())
970 '<?xml version="1.0" ?><manifest>' 994 == '<?xml version="1.0" ?><manifest>'
971 '<remote fetch="http://localhost" name="default-remote"/>' 995 '<remote fetch="http://localhost" name="default-remote"/>'
972 '<remote fetch="http://localhost" name="superproject-remote"/>' 996 '<remote fetch="http://localhost" name="superproject-remote"/>'
973 '<default remote="default-remote" revision="refs/heads/main"/>' 997 '<default remote="default-remote" revision="refs/heads/main"/>'
974 '<superproject name="platform/superproject" remote="superproject-remote"/>' # noqa: E501 998 '<superproject name="platform/superproject" remote="superproject-remote"/>' # noqa: E501
975 "</manifest>", 999 "</manifest>"
976 ) 1000 )
977 1001
978 def test_defalut_remote(self): 1002 def test_default_remote(self, repo_client: RepoClient) -> None:
979 """Check superproject settings with a default remote.""" 1003 """Check superproject settings with a default remote."""
980 manifest = self.getXmlManifest( 1004 manifest = repo_client.get_xml_manifest(
981 """ 1005 """
982<manifest> 1006<manifest>
983 <remote name="default-remote" fetch="http://localhost" /> 1007 <remote name="default-remote" fetch="http://localhost" />
@@ -986,62 +1010,61 @@ class SuperProjectElementTests(ManifestParseTestCase):
986</manifest> 1010</manifest>
987""" 1011"""
988 ) 1012 )
989 self.assertEqual(manifest.superproject.name, "superproject") 1013 assert manifest.superproject.name == "superproject"
990 self.assertEqual(manifest.superproject.remote.name, "default-remote") 1014 assert manifest.superproject.remote.name == "default-remote"
991 self.assertEqual(manifest.superproject.revision, "refs/heads/main") 1015 assert manifest.superproject.revision == "refs/heads/main"
992 self.assertEqual( 1016 assert (
993 sort_attributes(manifest.ToXml().toxml()), 1017 sort_attributes(manifest.ToXml().toxml())
994 '<?xml version="1.0" ?><manifest>' 1018 == '<?xml version="1.0" ?><manifest>'
995 '<remote fetch="http://localhost" name="default-remote"/>' 1019 '<remote fetch="http://localhost" name="default-remote"/>'
996 '<default remote="default-remote" revision="refs/heads/main"/>' 1020 '<default remote="default-remote" revision="refs/heads/main"/>'
997 '<superproject name="superproject"/>' 1021 '<superproject name="superproject"/>'
998 "</manifest>", 1022 "</manifest>"
999 ) 1023 )
1000 1024
1001 1025
1002class ContactinfoElementTests(ManifestParseTestCase): 1026class TestContactinfoElement:
1003 """Tests for <contactinfo>.""" 1027 """Tests for <contactinfo>."""
1004 1028
1005 def test_contactinfo(self): 1029 def test_contactinfo(self, repo_client: RepoClient) -> None:
1006 """Check contactinfo settings.""" 1030 """Check contactinfo settings."""
1007 bugurl = "http://localhost/contactinfo" 1031 bugurl = "http://localhost/contactinfo"
1008 manifest = self.getXmlManifest( 1032 manifest = repo_client.get_xml_manifest(
1009 f""" 1033 f"""
1010<manifest> 1034<manifest>
1011 <contactinfo bugurl="{bugurl}"/> 1035 <contactinfo bugurl="{bugurl}"/>
1012</manifest> 1036</manifest>
1013""" 1037"""
1014 ) 1038 )
1015 self.assertEqual(manifest.contactinfo.bugurl, bugurl) 1039 assert manifest.contactinfo.bugurl == bugurl
1016 self.assertEqual( 1040 assert (
1017 manifest.ToXml().toxml(), 1041 manifest.ToXml().toxml() == '<?xml version="1.0" ?><manifest>'
1018 '<?xml version="1.0" ?><manifest>'
1019 f'<contactinfo bugurl="{bugurl}"/>' 1042 f'<contactinfo bugurl="{bugurl}"/>'
1020 "</manifest>", 1043 "</manifest>"
1021 ) 1044 )
1022 1045
1023 1046
1024class DefaultElementTests(ManifestParseTestCase): 1047class TestDefaultElement:
1025 """Tests for <default>.""" 1048 """Tests for <default>."""
1026 1049
1027 def test_default(self): 1050 def test_default(self) -> None:
1028 """Check default settings.""" 1051 """Check default settings."""
1029 a = manifest_xml._Default() 1052 a = manifest_xml._Default()
1030 a.revisionExpr = "foo" 1053 a.revisionExpr = "foo"
1031 a.remote = manifest_xml._XmlRemote(name="remote") 1054 a.remote = manifest_xml._XmlRemote(name="remote")
1032 b = manifest_xml._Default() 1055 b = manifest_xml._Default()
1033 b.revisionExpr = "bar" 1056 b.revisionExpr = "bar"
1034 self.assertEqual(a, a) 1057 assert a == a
1035 self.assertNotEqual(a, b) 1058 assert a != b
1036 self.assertNotEqual(b, a.remote) 1059 assert b != a.remote
1037 self.assertNotEqual(a, 123) 1060 assert a != 123
1038 self.assertNotEqual(a, None) 1061 assert a is not None
1039 1062
1040 1063
1041class RemoteElementTests(ManifestParseTestCase): 1064class TestRemoteElement:
1042 """Tests for <remote>.""" 1065 """Tests for <remote>."""
1043 1066
1044 def test_remote(self): 1067 def test_remote(self) -> None:
1045 """Check remote settings.""" 1068 """Check remote settings."""
1046 a = manifest_xml._XmlRemote(name="foo") 1069 a = manifest_xml._XmlRemote(name="foo")
1047 a.AddAnnotation("key1", "value1", "true") 1070 a.AddAnnotation("key1", "value1", "true")
@@ -1051,20 +1074,21 @@ class RemoteElementTests(ManifestParseTestCase):
1051 c.AddAnnotation("key1", "value2", "true") 1074 c.AddAnnotation("key1", "value2", "true")
1052 d = manifest_xml._XmlRemote(name="foo") 1075 d = manifest_xml._XmlRemote(name="foo")
1053 d.AddAnnotation("key1", "value1", "false") 1076 d.AddAnnotation("key1", "value1", "false")
1054 self.assertEqual(a, a) 1077 assert a == a
1055 self.assertNotEqual(a, b) 1078 assert a != b
1056 self.assertNotEqual(a, c) 1079 assert a != c
1057 self.assertNotEqual(a, d) 1080 assert a != d
1058 self.assertNotEqual(a, manifest_xml._Default()) 1081 assert a != manifest_xml._Default()
1059 self.assertNotEqual(a, 123) 1082 assert a != 123
1060 self.assertNotEqual(a, None) 1083 assert a is not None
1061 1084
1062 1085
1063class RemoveProjectElementTests(ManifestParseTestCase): 1086class TestRemoveProjectElement:
1064 """Tests for <remove-project>.""" 1087 """Tests for <remove-project>."""
1065 1088
1066 def test_remove_one_project(self): 1089 def test_remove_one_project(self, repo_client: RepoClient) -> None:
1067 manifest = self.getXmlManifest( 1090 """Check removal of a single project."""
1091 manifest = repo_client.get_xml_manifest(
1068 """ 1092 """
1069<manifest> 1093<manifest>
1070 <remote name="default-remote" fetch="http://localhost" /> 1094 <remote name="default-remote" fetch="http://localhost" />
@@ -1074,10 +1098,13 @@ class RemoveProjectElementTests(ManifestParseTestCase):
1074</manifest> 1098</manifest>
1075""" 1099"""
1076 ) 1100 )
1077 self.assertEqual(manifest.projects, []) 1101 assert manifest.projects == []
1078 1102
1079 def test_remove_one_project_one_remains(self): 1103 def test_remove_one_project_one_remains(
1080 manifest = self.getXmlManifest( 1104 self, repo_client: RepoClient
1105 ) -> None:
1106 """Check removal of one project while another remains."""
1107 manifest = repo_client.get_xml_manifest(
1081 """ 1108 """
1082<manifest> 1109<manifest>
1083 <remote name="default-remote" fetch="http://localhost" /> 1110 <remote name="default-remote" fetch="http://localhost" />
@@ -1089,24 +1116,30 @@ class RemoveProjectElementTests(ManifestParseTestCase):
1089""" 1116"""
1090 ) 1117 )
1091 1118
1092 self.assertEqual(len(manifest.projects), 1) 1119 assert len(manifest.projects) == 1
1093 self.assertEqual(manifest.projects[0].name, "yourproject") 1120 assert manifest.projects[0].name == "yourproject"
1094 1121
1095 def test_remove_one_project_doesnt_exist(self): 1122 def test_remove_one_project_doesnt_exist(
1096 with self.assertRaises(manifest_xml.ManifestParseError): 1123 self, repo_client: RepoClient
1097 manifest = self.getXmlManifest( 1124 ) -> None:
1098 """ 1125 """Check removal of non-existent project fails."""
1126 manifest = repo_client.get_xml_manifest(
1127 """
1099<manifest> 1128<manifest>
1100 <remote name="default-remote" fetch="http://localhost" /> 1129 <remote name="default-remote" fetch="http://localhost" />
1101 <default remote="default-remote" revision="refs/heads/main" /> 1130 <default remote="default-remote" revision="refs/heads/main" />
1102 <remove-project name="myproject" /> 1131 <remove-project name="myproject" />
1103</manifest> 1132</manifest>
1104""" 1133"""
1105 ) 1134 )
1135 with pytest.raises(error.ManifestParseError):
1106 manifest.projects 1136 manifest.projects
1107 1137
1108 def test_remove_one_optional_project_doesnt_exist(self): 1138 def test_remove_one_optional_project_doesnt_exist(
1109 manifest = self.getXmlManifest( 1139 self, repo_client: RepoClient
1140 ) -> None:
1141 """Check optional removal of non-existent project passes."""
1142 manifest = repo_client.get_xml_manifest(
1110 """ 1143 """
1111<manifest> 1144<manifest>
1112 <remote name="default-remote" fetch="http://localhost" /> 1145 <remote name="default-remote" fetch="http://localhost" />
@@ -1115,10 +1148,11 @@ class RemoveProjectElementTests(ManifestParseTestCase):
1115</manifest> 1148</manifest>
1116""" 1149"""
1117 ) 1150 )
1118 self.assertEqual(manifest.projects, []) 1151 assert manifest.projects == []
1119 1152
1120 def test_remove_using_path_attrib(self): 1153 def test_remove_using_path_attrib(self, repo_client: RepoClient) -> None:
1121 manifest = self.getXmlManifest( 1154 """Check removal using name and path attributes."""
1155 manifest = repo_client.get_xml_manifest(
1122 """ 1156 """
1123<manifest> 1157<manifest>
1124 <remote name="default-remote" fetch="http://localhost" /> 1158 <remote name="default-remote" fetch="http://localhost" />
@@ -1145,18 +1179,21 @@ class RemoveProjectElementTests(ManifestParseTestCase):
1145 for proj in manifest.projects: 1179 for proj in manifest.projects:
1146 if proj.name == "project1": 1180 if proj.name == "project1":
1147 found_proj1_path1 = True 1181 found_proj1_path1 = True
1148 self.assertEqual(proj.relpath, "tests/path1") 1182 assert proj.relpath == "tests/path1"
1149 if proj.name == "project2": 1183 if proj.name == "project2":
1150 found_proj2 = True 1184 found_proj2 = True
1151 self.assertNotEqual(proj.name, "project3") 1185 assert proj.name != "project3"
1152 self.assertNotEqual(proj.name, "project4") 1186 assert proj.name != "project4"
1153 self.assertNotEqual(proj.name, "project5") 1187 assert proj.name != "project5"
1154 self.assertNotEqual(proj.name, "project6") 1188 assert proj.name != "project6"
1155 self.assertTrue(found_proj1_path1) 1189 assert found_proj1_path1
1156 self.assertTrue(found_proj2) 1190 assert found_proj2
1157 1191
1158 def test_base_revision_checks_on_patching(self): 1192 def test_base_revision_checks_on_patching(
1159 manifest_fail_wrong_tag = self.getXmlManifest( 1193 self, repo_client: RepoClient
1194 ) -> None:
1195 """Check base-rev validation during patching."""
1196 manifest_fail_wrong_tag = repo_client.get_xml_manifest(
1160 """ 1197 """
1161<manifest> 1198<manifest>
1162 <remote name="default-remote" fetch="http://localhost" /> 1199 <remote name="default-remote" fetch="http://localhost" />
@@ -1166,10 +1203,10 @@ class RemoveProjectElementTests(ManifestParseTestCase):
1166</manifest> 1203</manifest>
1167""" 1204"""
1168 ) 1205 )
1169 with self.assertRaises(error.ManifestParseError): 1206 with pytest.raises(error.ManifestParseError):
1170 manifest_fail_wrong_tag.ToXml() 1207 manifest_fail_wrong_tag.ToXml()
1171 1208
1172 manifest_fail_remove = self.getXmlManifest( 1209 manifest_fail_remove = repo_client.get_xml_manifest(
1173 """ 1210 """
1174<manifest> 1211<manifest>
1175 <remote name="default-remote" fetch="http://localhost" /> 1212 <remote name="default-remote" fetch="http://localhost" />
@@ -1179,10 +1216,10 @@ class RemoveProjectElementTests(ManifestParseTestCase):
1179</manifest> 1216</manifest>
1180""" 1217"""
1181 ) 1218 )
1182 with self.assertRaises(error.ManifestParseError): 1219 with pytest.raises(error.ManifestParseError):
1183 manifest_fail_remove.ToXml() 1220 manifest_fail_remove.ToXml()
1184 1221
1185 manifest_fail_extend = self.getXmlManifest( 1222 manifest_fail_extend = repo_client.get_xml_manifest(
1186 """ 1223 """
1187<manifest> 1224<manifest>
1188 <remote name="default-remote" fetch="http://localhost" /> 1225 <remote name="default-remote" fetch="http://localhost" />
@@ -1192,10 +1229,10 @@ class RemoveProjectElementTests(ManifestParseTestCase):
1192</manifest> 1229</manifest>
1193""" 1230"""
1194 ) 1231 )
1195 with self.assertRaises(error.ManifestParseError): 1232 with pytest.raises(error.ManifestParseError):
1196 manifest_fail_extend.ToXml() 1233 manifest_fail_extend.ToXml()
1197 1234
1198 manifest_fail_unknown = self.getXmlManifest( 1235 manifest_fail_unknown = repo_client.get_xml_manifest(
1199 """ 1236 """
1200<manifest> 1237<manifest>
1201 <remote name="default-remote" fetch="http://localhost" /> 1238 <remote name="default-remote" fetch="http://localhost" />
@@ -1205,10 +1242,10 @@ class RemoveProjectElementTests(ManifestParseTestCase):
1205</manifest> 1242</manifest>
1206""" 1243"""
1207 ) 1244 )
1208 with self.assertRaises(error.ManifestParseError): 1245 with pytest.raises(error.ManifestParseError):
1209 manifest_fail_unknown.ToXml() 1246 manifest_fail_unknown.ToXml()
1210 1247
1211 manifest_ok = self.getXmlManifest( 1248 manifest_ok = repo_client.get_xml_manifest(
1212 """ 1249 """
1213<manifest> 1250<manifest>
1214 <remote name="default-remote" fetch="http://localhost" /> 1251 <remote name="default-remote" fetch="http://localhost" />
@@ -1234,18 +1271,21 @@ class RemoveProjectElementTests(ManifestParseTestCase):
1234 found_proj2 = True 1271 found_proj2 = True
1235 if proj.name == "project3": 1272 if proj.name == "project3":
1236 found_proj3 = True 1273 found_proj3 = True
1237 self.assertNotEqual(proj.name, "project1") 1274 assert proj.name != "project1"
1238 self.assertNotEqual(proj.name, "project4") 1275 assert proj.name != "project4"
1239 self.assertTrue(found_proj2) 1276 assert found_proj2
1240 self.assertTrue(found_proj3) 1277 assert found_proj3
1241 self.assertTrue(len(manifest_ok.projects) == 2) 1278 assert len(manifest_ok.projects) == 2
1242 1279
1243 1280
1244class ExtendProjectElementTests(ManifestParseTestCase): 1281class TestExtendProjectElement:
1245 """Tests for <extend-project>.""" 1282 """Tests for <extend-project>."""
1246 1283
1247 def test_extend_project_dest_path_single_match(self): 1284 def test_extend_project_dest_path_single_match(
1248 manifest = self.getXmlManifest( 1285 self, repo_client: RepoClient
1286 ) -> None:
1287 """Check dest-path when single match exists."""
1288 manifest = repo_client.get_xml_manifest(
1249 """ 1289 """
1250<manifest> 1290<manifest>
1251 <remote name="default-remote" fetch="http://localhost" /> 1291 <remote name="default-remote" fetch="http://localhost" />
@@ -1255,13 +1295,15 @@ class ExtendProjectElementTests(ManifestParseTestCase):
1255</manifest> 1295</manifest>
1256""" 1296"""
1257 ) 1297 )
1258 self.assertEqual(len(manifest.projects), 1) 1298 assert len(manifest.projects) == 1
1259 self.assertEqual(manifest.projects[0].relpath, "bar") 1299 assert manifest.projects[0].relpath == "bar"
1260 1300
1261 def test_extend_project_dest_path_multi_match(self): 1301 def test_extend_project_dest_path_multi_match(
1262 with self.assertRaises(manifest_xml.ManifestParseError): 1302 self, repo_client: RepoClient
1263 manifest = self.getXmlManifest( 1303 ) -> None:
1264 """ 1304 """Check dest-path when multiple matches exist fails."""
1305 manifest = repo_client.get_xml_manifest(
1306 """
1265<manifest> 1307<manifest>
1266 <remote name="default-remote" fetch="http://localhost" /> 1308 <remote name="default-remote" fetch="http://localhost" />
1267 <default remote="default-remote" revision="refs/heads/main" /> 1309 <default remote="default-remote" revision="refs/heads/main" />
@@ -1270,11 +1312,15 @@ class ExtendProjectElementTests(ManifestParseTestCase):
1270 <extend-project name="myproject" dest-path="bar" /> 1312 <extend-project name="myproject" dest-path="bar" />
1271</manifest> 1313</manifest>
1272""" 1314"""
1273 ) 1315 )
1316 with pytest.raises(error.ManifestParseError):
1274 manifest.projects 1317 manifest.projects
1275 1318
1276 def test_extend_project_dest_path_multi_match_path_specified(self): 1319 def test_extend_project_dest_path_multi_match_path_specified(
1277 manifest = self.getXmlManifest( 1320 self, repo_client: RepoClient
1321 ) -> None:
1322 """Check dest-path when path is specified for multi-match."""
1323 manifest = repo_client.get_xml_manifest(
1278 """ 1324 """
1279<manifest> 1325<manifest>
1280 <remote name="default-remote" fetch="http://localhost" /> 1326 <remote name="default-remote" fetch="http://localhost" />
@@ -1285,29 +1331,32 @@ class ExtendProjectElementTests(ManifestParseTestCase):
1285</manifest> 1331</manifest>
1286""" 1332"""
1287 ) 1333 )
1288 self.assertEqual(len(manifest.projects), 2) 1334 assert len(manifest.projects) == 2
1289 if manifest.projects[0].relpath == "y": 1335 if manifest.projects[0].relpath == "y":
1290 self.assertEqual(manifest.projects[1].relpath, "bar") 1336 assert manifest.projects[1].relpath == "bar"
1291 else: 1337 else:
1292 self.assertEqual(manifest.projects[0].relpath, "bar") 1338 assert manifest.projects[0].relpath == "bar"
1293 self.assertEqual(manifest.projects[1].relpath, "y") 1339 assert manifest.projects[1].relpath == "y"
1294 1340
1295 def test_extend_project_dest_branch(self): 1341 def test_extend_project_dest_branch(self, repo_client: RepoClient) -> None:
1296 manifest = self.getXmlManifest( 1342 """Check dest-branch update via extend-project."""
1343 manifest = repo_client.get_xml_manifest(
1297 """ 1344 """
1298<manifest> 1345<manifest>
1299 <remote name="default-remote" fetch="http://localhost" /> 1346 <remote name="default-remote" fetch="http://localhost" />
1300 <default remote="default-remote" revision="refs/heads/main" dest-branch="foo" /> 1347 <default remote="default-remote" revision="refs/heads/main"
1348 dest-branch="foo" />
1301 <project name="myproject" /> 1349 <project name="myproject" />
1302 <extend-project name="myproject" dest-branch="bar" /> 1350 <extend-project name="myproject" dest-branch="bar" />
1303</manifest> 1351</manifest>
1304""" # noqa: E501 1352"""
1305 ) 1353 )
1306 self.assertEqual(len(manifest.projects), 1) 1354 assert len(manifest.projects) == 1
1307 self.assertEqual(manifest.projects[0].dest_branch, "bar") 1355 assert manifest.projects[0].dest_branch == "bar"
1308 1356
1309 def test_extend_project_upstream(self): 1357 def test_extend_project_upstream(self, repo_client: RepoClient) -> None:
1310 manifest = self.getXmlManifest( 1358 """Check upstream update via extend-project."""
1359 manifest = repo_client.get_xml_manifest(
1311 """ 1360 """
1312<manifest> 1361<manifest>
1313 <remote name="default-remote" fetch="http://localhost" /> 1362 <remote name="default-remote" fetch="http://localhost" />
@@ -1317,11 +1366,12 @@ class ExtendProjectElementTests(ManifestParseTestCase):
1317</manifest> 1366</manifest>
1318""" 1367"""
1319 ) 1368 )
1320 self.assertEqual(len(manifest.projects), 1) 1369 assert len(manifest.projects) == 1
1321 self.assertEqual(manifest.projects[0].upstream, "bar") 1370 assert manifest.projects[0].upstream == "bar"
1322 1371
1323 def test_extend_project_copyfiles(self): 1372 def test_extend_project_copyfiles(self, repo_client: RepoClient) -> None:
1324 manifest = self.getXmlManifest( 1373 """Check copyfile addition via extend-project."""
1374 manifest = repo_client.get_xml_manifest(
1325 """ 1375 """
1326<manifest> 1376<manifest>
1327 <remote name="default-remote" fetch="http://localhost" /> 1377 <remote name="default-remote" fetch="http://localhost" />
@@ -1333,21 +1383,24 @@ class ExtendProjectElementTests(ManifestParseTestCase):
1333</manifest> 1383</manifest>
1334""" 1384"""
1335 ) 1385 )
1336 self.assertEqual(list(manifest.projects[0].copyfiles)[0].src, "foo") 1386 assert list(manifest.projects[0].copyfiles)[0].src == "foo"
1337 self.assertEqual(list(manifest.projects[0].copyfiles)[0].dest, "bar") 1387 assert list(manifest.projects[0].copyfiles)[0].dest == "bar"
1338 self.assertEqual( 1388 assert (
1339 sort_attributes(manifest.ToXml().toxml()), 1389 sort_attributes(manifest.ToXml().toxml())
1340 '<?xml version="1.0" ?><manifest>' 1390 == '<?xml version="1.0" ?><manifest>'
1341 '<remote fetch="http://localhost" name="default-remote"/>' 1391 '<remote fetch="http://localhost" name="default-remote"/>'
1342 '<default remote="default-remote" revision="refs/heads/main"/>' 1392 '<default remote="default-remote" revision="refs/heads/main"/>'
1343 '<project name="myproject">' 1393 '<project name="myproject">'
1344 '<copyfile dest="bar" src="foo"/>' 1394 '<copyfile dest="bar" src="foo"/>'
1345 "</project>" 1395 "</project>"
1346 "</manifest>", 1396 "</manifest>"
1347 ) 1397 )
1348 1398
1349 def test_extend_project_duplicate_copyfiles(self): 1399 def test_extend_project_duplicate_copyfiles(
1350 root_m = self.manifest_dir / "root.xml" 1400 self, repo_client: RepoClient
1401 ) -> None:
1402 """Check duplicate copyfile handling in includes."""
1403 root_m = repo_client.manifest_dir / "root.xml"
1351 root_m.write_text( 1404 root_m.write_text(
1352 """ 1405 """
1353<manifest> 1406<manifest>
@@ -1359,21 +1412,21 @@ class ExtendProjectElementTests(ManifestParseTestCase):
1359</manifest> 1412</manifest>
1360""" 1413"""
1361 ) 1414 )
1362 (self.manifest_dir / "man1.xml").write_text( 1415 (repo_client.manifest_dir / "man1.xml").write_text(
1363 """ 1416 """
1364<manifest> 1417<manifest>
1365 <include name="common.xml" /> 1418 <include name="common.xml" />
1366</manifest> 1419</manifest>
1367""" 1420"""
1368 ) 1421 )
1369 (self.manifest_dir / "man2.xml").write_text( 1422 (repo_client.manifest_dir / "man2.xml").write_text(
1370 """ 1423 """
1371<manifest> 1424<manifest>
1372 <include name="common.xml" /> 1425 <include name="common.xml" />
1373</manifest> 1426</manifest>
1374""" 1427"""
1375 ) 1428 )
1376 (self.manifest_dir / "common.xml").write_text( 1429 (repo_client.manifest_dir / "common.xml").write_text(
1377 """ 1430 """
1378<manifest> 1431<manifest>
1379 <extend-project name="myproject"> 1432 <extend-project name="myproject">
@@ -1382,13 +1435,16 @@ class ExtendProjectElementTests(ManifestParseTestCase):
1382</manifest> 1435</manifest>
1383""" 1436"""
1384 ) 1437 )
1385 manifest = manifest_xml.XmlManifest(str(self.repodir), str(root_m)) 1438 manifest = manifest_xml.XmlManifest(
1386 self.assertEqual(len(manifest.projects[0].copyfiles), 1) 1439 str(repo_client.repodir), str(root_m)
1387 self.assertEqual(list(manifest.projects[0].copyfiles)[0].src, "foo") 1440 )
1388 self.assertEqual(list(manifest.projects[0].copyfiles)[0].dest, "bar") 1441 assert len(manifest.projects[0].copyfiles) == 1
1442 assert list(manifest.projects[0].copyfiles)[0].src == "foo"
1443 assert list(manifest.projects[0].copyfiles)[0].dest == "bar"
1389 1444
1390 def test_extend_project_linkfiles(self): 1445 def test_extend_project_linkfiles(self, repo_client: RepoClient) -> None:
1391 manifest = self.getXmlManifest( 1446 """Check linkfile addition via extend-project."""
1447 manifest = repo_client.get_xml_manifest(
1392 """ 1448 """
1393<manifest> 1449<manifest>
1394 <remote name="default-remote" fetch="http://localhost" /> 1450 <remote name="default-remote" fetch="http://localhost" />
@@ -1400,21 +1456,24 @@ class ExtendProjectElementTests(ManifestParseTestCase):
1400</manifest> 1456</manifest>
1401""" 1457"""
1402 ) 1458 )
1403 self.assertEqual(list(manifest.projects[0].linkfiles)[0].src, "foo") 1459 assert list(manifest.projects[0].linkfiles)[0].src == "foo"
1404 self.assertEqual(list(manifest.projects[0].linkfiles)[0].dest, "bar") 1460 assert list(manifest.projects[0].linkfiles)[0].dest == "bar"
1405 self.assertEqual( 1461 assert (
1406 sort_attributes(manifest.ToXml().toxml()), 1462 sort_attributes(manifest.ToXml().toxml())
1407 '<?xml version="1.0" ?><manifest>' 1463 == '<?xml version="1.0" ?><manifest>'
1408 '<remote fetch="http://localhost" name="default-remote"/>' 1464 '<remote fetch="http://localhost" name="default-remote"/>'
1409 '<default remote="default-remote" revision="refs/heads/main"/>' 1465 '<default remote="default-remote" revision="refs/heads/main"/>'
1410 '<project name="myproject">' 1466 '<project name="myproject">'
1411 '<linkfile dest="bar" src="foo"/>' 1467 '<linkfile dest="bar" src="foo"/>'
1412 "</project>" 1468 "</project>"
1413 "</manifest>", 1469 "</manifest>"
1414 ) 1470 )
1415 1471
1416 def test_extend_project_duplicate_linkfiles(self): 1472 def test_extend_project_duplicate_linkfiles(
1417 root_m = self.manifest_dir / "root.xml" 1473 self, repo_client: RepoClient
1474 ) -> None:
1475 """Check duplicate linkfile handling in includes."""
1476 root_m = repo_client.manifest_dir / "root.xml"
1418 root_m.write_text( 1477 root_m.write_text(
1419 """ 1478 """
1420<manifest> 1479<manifest>
@@ -1426,21 +1485,21 @@ class ExtendProjectElementTests(ManifestParseTestCase):
1426</manifest> 1485</manifest>
1427""" 1486"""
1428 ) 1487 )
1429 (self.manifest_dir / "man1.xml").write_text( 1488 (repo_client.manifest_dir / "man1.xml").write_text(
1430 """ 1489 """
1431<manifest> 1490<manifest>
1432 <include name="common.xml" /> 1491 <include name="common.xml" />
1433</manifest> 1492</manifest>
1434""" 1493"""
1435 ) 1494 )
1436 (self.manifest_dir / "man2.xml").write_text( 1495 (repo_client.manifest_dir / "man2.xml").write_text(
1437 """ 1496 """
1438<manifest> 1497<manifest>
1439 <include name="common.xml" /> 1498 <include name="common.xml" />
1440</manifest> 1499</manifest>
1441""" 1500"""
1442 ) 1501 )
1443 (self.manifest_dir / "common.xml").write_text( 1502 (repo_client.manifest_dir / "common.xml").write_text(
1444 """ 1503 """
1445<manifest> 1504<manifest>
1446 <extend-project name="myproject"> 1505 <extend-project name="myproject">
@@ -1449,13 +1508,16 @@ class ExtendProjectElementTests(ManifestParseTestCase):
1449</manifest> 1508</manifest>
1450""" 1509"""
1451 ) 1510 )
1452 manifest = manifest_xml.XmlManifest(str(self.repodir), str(root_m)) 1511 manifest = manifest_xml.XmlManifest(
1453 self.assertEqual(len(manifest.projects[0].linkfiles), 1) 1512 str(repo_client.repodir), str(root_m)
1454 self.assertEqual(list(manifest.projects[0].linkfiles)[0].src, "foo") 1513 )
1455 self.assertEqual(list(manifest.projects[0].linkfiles)[0].dest, "bar") 1514 assert len(manifest.projects[0].linkfiles) == 1
1515 assert list(manifest.projects[0].linkfiles)[0].src == "foo"
1516 assert list(manifest.projects[0].linkfiles)[0].dest == "bar"
1456 1517
1457 def test_extend_project_annotations(self): 1518 def test_extend_project_annotations(self, repo_client: RepoClient) -> None:
1458 manifest = self.getXmlManifest( 1519 """Check annotation addition via extend-project."""
1520 manifest = repo_client.get_xml_manifest(
1459 """ 1521 """
1460<manifest> 1522<manifest>
1461 <remote name="default-remote" fetch="http://localhost" /> 1523 <remote name="default-remote" fetch="http://localhost" />
@@ -1467,21 +1529,24 @@ class ExtendProjectElementTests(ManifestParseTestCase):
1467</manifest> 1529</manifest>
1468""" 1530"""
1469 ) 1531 )
1470 self.assertEqual(manifest.projects[0].annotations[0].name, "foo") 1532 assert manifest.projects[0].annotations[0].name == "foo"
1471 self.assertEqual(manifest.projects[0].annotations[0].value, "bar") 1533 assert manifest.projects[0].annotations[0].value == "bar"
1472 self.assertEqual( 1534 assert (
1473 sort_attributes(manifest.ToXml().toxml()), 1535 sort_attributes(manifest.ToXml().toxml())
1474 '<?xml version="1.0" ?><manifest>' 1536 == '<?xml version="1.0" ?><manifest>'
1475 '<remote fetch="http://localhost" name="default-remote"/>' 1537 '<remote fetch="http://localhost" name="default-remote"/>'
1476 '<default remote="default-remote" revision="refs/heads/main"/>' 1538 '<default remote="default-remote" revision="refs/heads/main"/>'
1477 '<project name="myproject">' 1539 '<project name="myproject">'
1478 '<annotation name="foo" value="bar"/>' 1540 '<annotation name="foo" value="bar"/>'
1479 "</project>" 1541 "</project>"
1480 "</manifest>", 1542 "</manifest>"
1481 ) 1543 )
1482 1544
1483 def test_extend_project_annotations_multiples(self): 1545 def test_extend_project_annotations_multiples(
1484 manifest = self.getXmlManifest( 1546 self, repo_client: RepoClient
1547 ) -> None:
1548 """Check multiple annotation additions via extend-project."""
1549 manifest = repo_client.get_xml_manifest(
1485 """ 1550 """
1486<manifest> 1551<manifest>
1487 <remote name="default-remote" fetch="http://localhost" /> 1552 <remote name="default-remote" fetch="http://localhost" />
@@ -1497,18 +1562,17 @@ class ExtendProjectElementTests(ManifestParseTestCase):
1497</manifest> 1562</manifest>
1498""" 1563"""
1499 ) 1564 )
1500 self.assertEqual( 1565 assert [
1501 [(a.name, a.value) for a in manifest.projects[0].annotations], 1566 (a.name, a.value) for a in manifest.projects[0].annotations
1502 [ 1567 ] == [
1503 ("foo", "bar"), 1568 ("foo", "bar"),
1504 ("few", "bar"), 1569 ("few", "bar"),
1505 ("foo", "new_bar"), 1570 ("foo", "new_bar"),
1506 ("new", "anno"), 1571 ("new", "anno"),
1507 ], 1572 ]
1508 ) 1573 assert (
1509 self.assertEqual( 1574 sort_attributes(manifest.ToXml().toxml())
1510 sort_attributes(manifest.ToXml().toxml()), 1575 == '<?xml version="1.0" ?><manifest>'
1511 '<?xml version="1.0" ?><manifest>'
1512 '<remote fetch="http://localhost" name="default-remote"/>' 1576 '<remote fetch="http://localhost" name="default-remote"/>'
1513 '<default remote="default-remote" revision="refs/heads/main"/>' 1577 '<default remote="default-remote" revision="refs/heads/main"/>'
1514 '<project name="myproject">' 1578 '<project name="myproject">'
@@ -1517,81 +1581,78 @@ class ExtendProjectElementTests(ManifestParseTestCase):
1517 '<annotation name="foo" value="new_bar"/>' 1581 '<annotation name="foo" value="new_bar"/>'
1518 '<annotation name="new" value="anno"/>' 1582 '<annotation name="new" value="anno"/>'
1519 "</project>" 1583 "</project>"
1520 "</manifest>", 1584 "</manifest>"
1521 ) 1585 )
1522 1586
1523 1587
1524class NormalizeUrlTests(ManifestParseTestCase): 1588class TestNormalizeUrl:
1525 """Tests for normalize_url() in manifest_xml.py""" 1589 """Tests for normalize_url() in manifest_xml.py"""
1526 1590
1527 def test_has_trailing_slash(self): 1591 def test_has_trailing_slash(self) -> None:
1592 """Trailing slashes should be removed."""
1528 url = "http://foo.com/bar/baz/" 1593 url = "http://foo.com/bar/baz/"
1529 self.assertEqual( 1594 assert manifest_xml.normalize_url(url) == "http://foo.com/bar/baz"
1530 "http://foo.com/bar/baz", manifest_xml.normalize_url(url)
1531 )
1532 1595
1533 url = "http://foo.com/bar/" 1596 url = "http://foo.com/bar/"
1534 self.assertEqual("http://foo.com/bar", manifest_xml.normalize_url(url)) 1597 assert manifest_xml.normalize_url(url) == "http://foo.com/bar"
1535 1598
1536 def test_has_leading_slash(self): 1599 def test_has_leading_slash(self) -> None:
1537 """SCP-like syntax except a / comes before the : which git disallows.""" 1600 """SCP-like syntax except a / comes before the : which git disallows."""
1538 url = "/git@foo.com:bar/baf" 1601 url = "/git@foo.com:bar/baf"
1539 self.assertEqual(url, manifest_xml.normalize_url(url)) 1602 assert manifest_xml.normalize_url(url) == url
1540 1603
1541 url = "gi/t@foo.com:bar/baf" 1604 url = "gi/t@foo.com:bar/baf"
1542 self.assertEqual(url, manifest_xml.normalize_url(url)) 1605 assert manifest_xml.normalize_url(url) == url
1543 1606
1544 url = "git@fo/o.com:bar/baf" 1607 url = "git@fo/o.com:bar/baf"
1545 self.assertEqual(url, manifest_xml.normalize_url(url)) 1608 assert manifest_xml.normalize_url(url) == url
1546 1609
1547 def test_has_no_scheme(self): 1610 def test_has_no_scheme(self) -> None:
1548 """Deal with cases where we have no scheme, but we also 1611 """Deal with cases where we have no scheme, but we also
1549 aren't dealing with the git SCP-like syntax 1612 aren't dealing with the git SCP-like syntax
1550 """ 1613 """
1551 url = "foo.com/baf/bat" 1614 url = "foo.com/baf/bat"
1552 self.assertEqual(url, manifest_xml.normalize_url(url)) 1615 assert manifest_xml.normalize_url(url) == url
1553 1616
1554 url = "foo.com/baf" 1617 url = "foo.com/baf"
1555 self.assertEqual(url, manifest_xml.normalize_url(url)) 1618 assert manifest_xml.normalize_url(url) == url
1556 1619
1557 url = "git@foo.com/baf/bat" 1620 url = "git@foo.com/baf/bat"
1558 self.assertEqual(url, manifest_xml.normalize_url(url)) 1621 assert manifest_xml.normalize_url(url) == url
1559 1622
1560 url = "git@foo.com/baf" 1623 url = "git@foo.com/baf"
1561 self.assertEqual(url, manifest_xml.normalize_url(url)) 1624 assert manifest_xml.normalize_url(url) == url
1562 1625
1563 url = "/file/path/here" 1626 url = "/file/path/here"
1564 self.assertEqual(url, manifest_xml.normalize_url(url)) 1627 assert manifest_xml.normalize_url(url) == url
1565 1628
1566 def test_has_no_scheme_matches_scp_like_syntax(self): 1629 def test_has_no_scheme_matches_scp_like_syntax(self) -> None:
1630 """SCP-like syntax should be converted to ssh://."""
1567 url = "git@foo.com:bar/baf" 1631 url = "git@foo.com:bar/baf"
1568 self.assertEqual( 1632 assert manifest_xml.normalize_url(url) == "ssh://git@foo.com/bar/baf"
1569 "ssh://git@foo.com/bar/baf", manifest_xml.normalize_url(url)
1570 )
1571 1633
1572 url = "git@foo.com:bar/" 1634 url = "git@foo.com:bar/"
1573 self.assertEqual( 1635 assert manifest_xml.normalize_url(url) == "ssh://git@foo.com/bar"
1574 "ssh://git@foo.com/bar", manifest_xml.normalize_url(url)
1575 )
1576 1636
1577 def test_remote_url_resolution(self): 1637 def test_remote_url_resolution(self) -> None:
1638 """Check resolvedFetchUrl calculation."""
1578 remote = manifest_xml._XmlRemote( 1639 remote = manifest_xml._XmlRemote(
1579 name="foo", 1640 name="foo",
1580 fetch="git@github.com:org2/", 1641 fetch="git@github.com:org2/",
1581 manifestUrl="git@github.com:org2/custom_manifest.git", 1642 manifestUrl="git@github.com:org2/custom_manifest.git",
1582 ) 1643 )
1583 self.assertEqual("ssh://git@github.com/org2", remote.resolvedFetchUrl) 1644 assert remote.resolvedFetchUrl == "ssh://git@github.com/org2"
1584 1645
1585 remote = manifest_xml._XmlRemote( 1646 remote = manifest_xml._XmlRemote(
1586 name="foo", 1647 name="foo",
1587 fetch="ssh://git@github.com/org2/", 1648 fetch="ssh://git@github.com/org2/",
1588 manifestUrl="git@github.com:org2/custom_manifest.git", 1649 manifestUrl="git@github.com:org2/custom_manifest.git",
1589 ) 1650 )
1590 self.assertEqual("ssh://git@github.com/org2", remote.resolvedFetchUrl) 1651 assert remote.resolvedFetchUrl == "ssh://git@github.com/org2"
1591 1652
1592 remote = manifest_xml._XmlRemote( 1653 remote = manifest_xml._XmlRemote(
1593 name="foo", 1654 name="foo",
1594 fetch="git@github.com:org2/", 1655 fetch="git@github.com:org2/",
1595 manifestUrl="ssh://git@github.com/org2/custom_manifest.git", 1656 manifestUrl="ssh://git@github.com/org2/custom_manifest.git",
1596 ) 1657 )
1597 self.assertEqual("ssh://git@github.com/org2", remote.resolvedFetchUrl) 1658 assert remote.resolvedFetchUrl == "ssh://git@github.com/org2"
diff --git a/tests/test_project.py b/tests/test_project.py
index 501707eaf..6a04ea455 100644
--- a/tests/test_project.py
+++ b/tests/test_project.py
@@ -17,10 +17,12 @@
17import contextlib 17import contextlib
18import os 18import os
19from pathlib import Path 19from pathlib import Path
20import shutil
20import subprocess 21import subprocess
21import tempfile 22import tempfile
22from typing import Optional 23from typing import Optional
23import unittest 24import unittest
25from unittest import mock
24 26
25import utils_for_test 27import utils_for_test
26 28
@@ -565,3 +567,230 @@ class ManifestPropertiesFetchedCorrectly(unittest.TestCase):
565 567
566 fakeproj.config.SetString("manifest.platform", "auto") 568 fakeproj.config.SetString("manifest.platform", "auto")
567 self.assertEqual(fakeproj.manifest_platform, "auto") 569 self.assertEqual(fakeproj.manifest_platform, "auto")
570
571
572class StatelessSyncTests(unittest.TestCase):
573 """Tests for stateless sync strategy."""
574
575 def _get_project(self, tempdir):
576 manifest = mock.MagicMock()
577 manifest.manifestProject.depth = None
578 manifest.manifestProject.dissociate = False
579 manifest.manifestProject.clone_filter = None
580 manifest.is_multimanifest = False
581 manifest.manifestProject.config.GetBoolean.return_value = False
582
583 remote = mock.MagicMock()
584 remote.name = "origin"
585 remote.url = "http://"
586
587 proj = project.Project(
588 manifest=manifest,
589 name="test-project",
590 remote=remote,
591 gitdir=os.path.join(tempdir, ".git"),
592 objdir=os.path.join(tempdir, ".git"),
593 worktree=tempdir,
594 relpath="test-project",
595 revisionExpr="1234abcd",
596 revisionId=None,
597 sync_strategy="stateless",
598 )
599 proj._CheckForImmutableRevision = mock.MagicMock(return_value=False)
600 proj._LsRemote = mock.MagicMock(
601 return_value="1234abcd\trefs/heads/main\n"
602 )
603 proj.bare_git = mock.MagicMock()
604 proj.bare_git.rev_parse.return_value = "5678abcd"
605 proj.bare_git.rev_list.return_value = ["0"]
606 proj.IsDirty = mock.MagicMock(return_value=False)
607 proj.GetBranches = mock.MagicMock(return_value=[])
608 proj.DeleteWorktree = mock.MagicMock()
609 proj._InitGitDir = mock.MagicMock()
610 proj._RemoteFetch = mock.MagicMock(return_value=True)
611 proj._InitRemote = mock.MagicMock()
612 proj._InitMRef = mock.MagicMock()
613 return proj
614
615 def test_sync_network_half_stateless_prune_needed(self):
616 """Test stateless sync queues prune when needed."""
617 with utils_for_test.TempGitTree() as tempdir:
618 proj = self._get_project(tempdir)
619 res = proj.Sync_NetworkHalf()
620
621 self.assertTrue(res.success)
622 proj.DeleteWorktree.assert_not_called()
623 self.assertTrue(proj.stateless_prune_needed)
624 proj._RemoteFetch.assert_called_once()
625
626 def test_sync_local_half_stateless_prune(self):
627 """Test stateless GC pruning is queued in Sync_LocalHalf."""
628 with utils_for_test.TempGitTree() as tempdir:
629 proj = self._get_project(tempdir)
630 proj.stateless_prune_needed = True
631
632 proj._Checkout = mock.MagicMock()
633 proj._InitWorkTree = mock.MagicMock()
634 proj.IsRebaseInProgress = mock.MagicMock(return_value=False)
635 proj.IsCherryPickInProgress = mock.MagicMock(return_value=False)
636 proj.bare_ref = mock.MagicMock()
637 proj.bare_ref.all = {}
638 proj.GetRevisionId = mock.MagicMock(return_value="1234abcd")
639 proj._CopyAndLinkFiles = mock.MagicMock()
640
641 proj.work_git = mock.MagicMock()
642 proj.work_git.GetHead.return_value = "5678abcd"
643
644 syncbuf = project.SyncBuffer(proj.config)
645
646 with mock.patch("project.GitCommand") as mock_git_cmd:
647 mock_cmd_instance = mock.MagicMock()
648 mock_cmd_instance.Wait.return_value = 0
649 mock_git_cmd.return_value = mock_cmd_instance
650
651 proj.Sync_LocalHalf(syncbuf)
652 syncbuf.Finish()
653
654 self.assertEqual(mock_git_cmd.call_count, 2)
655 mock_git_cmd.assert_any_call(
656 proj, ["reflog", "expire", "--expire=all", "--all"], bare=True
657 )
658 mock_git_cmd.assert_any_call(
659 proj,
660 ["gc", "--prune=now"],
661 bare=True,
662 capture_stdout=True,
663 capture_stderr=True,
664 )
665
666 def test_sync_network_half_stateless_skips_if_stash(self):
667 """Test stateless sync skips if stash exists."""
668 with utils_for_test.TempGitTree() as tempdir:
669 proj = self._get_project(tempdir)
670 proj.HasStash = mock.MagicMock(return_value=True)
671
672 res = proj.Sync_NetworkHalf()
673
674 self.assertTrue(res.success)
675 self.assertFalse(getattr(proj, "stateless_prune_needed", False))
676
677 def test_sync_network_half_stateless_skips_if_local_commits(self):
678 """Test stateless sync skips if there are local-only commits."""
679 with utils_for_test.TempGitTree() as tempdir:
680 proj = self._get_project(tempdir)
681 proj.bare_git.rev_list.return_value = ["1"]
682
683 res = proj.Sync_NetworkHalf()
684
685 self.assertTrue(res.success)
686 self.assertFalse(getattr(proj, "stateless_prune_needed", False))
687
688
689class SyncOptimizationTests(unittest.TestCase):
690 """Tests for sync optimization logic involving shallow clones."""
691
692 def _get_project(self, tempdir, depth=None):
693 manifest = mock.MagicMock()
694 manifest.manifestProject.depth = depth
695 manifest.manifestProject.dissociate = False
696 manifest.manifestProject.clone_filter = None
697 manifest.is_multimanifest = False
698 manifest.manifestProject.config.GetBoolean.return_value = False
699 manifest.IsMirror = False
700
701 remote = mock.MagicMock()
702 remote.name = "origin"
703 remote.url = "http://"
704
705 proj = project.Project(
706 manifest=manifest,
707 name="test-project",
708 remote=remote,
709 gitdir=os.path.join(tempdir, "gitdir"),
710 objdir=os.path.join(tempdir, "objdir"),
711 worktree=tempdir,
712 relpath="test-project",
713 revisionExpr="0123456789abcdef0123456789abcdef01234567",
714 revisionId=None,
715 )
716 proj._CheckForImmutableRevision = mock.MagicMock(return_value=True)
717 proj.DeleteWorktree = mock.MagicMock()
718 proj._InitGitDir = mock.MagicMock()
719 proj._InitRemote = mock.MagicMock()
720 proj._InitMRef = mock.MagicMock()
721 return proj
722
723 def test_sync_network_half_shallow_missing_fetches(self):
724 """Test Sync_NetworkHalf fetches if shallow file is missing."""
725 with utils_for_test.TempGitTree() as tempdir:
726 proj = self._get_project(tempdir, depth=1)
727 # Ensure gitdir does not exist to simulate new project
728 if os.path.exists(proj.gitdir):
729 shutil.rmtree(proj.gitdir)
730 shallow_path = os.path.join(proj.gitdir, "shallow")
731 if os.path.exists(shallow_path):
732 os.unlink(shallow_path)
733
734 proj._RemoteFetch = mock.MagicMock(return_value=True)
735
736 res = proj.Sync_NetworkHalf(optimized_fetch=True)
737
738 self.assertTrue(res.success)
739 proj._RemoteFetch.assert_called_once()
740
741 def test_sync_network_half_shallow_exists_skips(self):
742 """Test Sync_NetworkHalf skips fetch if shallow file exists."""
743 with utils_for_test.TempGitTree() as tempdir:
744 proj = self._get_project(tempdir, depth=1)
745 os.makedirs(proj.gitdir, exist_ok=True)
746 os.makedirs(proj.objdir, exist_ok=True)
747 with open(os.path.join(proj.gitdir, "shallow"), "w") as f:
748 f.write("")
749
750 proj._RemoteFetch = mock.MagicMock()
751
752 res = proj.Sync_NetworkHalf(optimized_fetch=True)
753
754 self.assertTrue(res.success)
755 proj._RemoteFetch.assert_not_called()
756
757 def test_remote_fetch_shallow_missing_fetches(self):
758 """Test _RemoteFetch fetches if shallow file is missing."""
759 with utils_for_test.TempGitTree() as tempdir:
760 proj = self._get_project(tempdir, depth=1)
761 shallow_path = os.path.join(proj.gitdir, "shallow")
762 if os.path.exists(shallow_path):
763 os.unlink(shallow_path)
764
765 with mock.patch("project.GitCommand") as mock_git_cmd:
766 mock_cmd_instance = mock.MagicMock()
767 mock_cmd_instance.Wait.return_value = 0
768 mock_git_cmd.return_value = mock_cmd_instance
769
770 res = proj._RemoteFetch(
771 current_branch_only=True,
772 depth=1,
773 use_superproject=False,
774 )
775
776 self.assertTrue(res)
777 mock_git_cmd.assert_called()
778
779 def test_remote_fetch_shallow_exists_skips(self):
780 """Test _RemoteFetch skips fetch if shallow file exists."""
781 with utils_for_test.TempGitTree() as tempdir:
782 proj = self._get_project(tempdir, depth=1)
783 os.makedirs(proj.gitdir, exist_ok=True)
784 os.makedirs(proj.objdir, exist_ok=True)
785 with open(os.path.join(proj.gitdir, "shallow"), "w") as f:
786 f.write("")
787
788 with mock.patch("project.GitCommand") as mock_git_cmd:
789 res = proj._RemoteFetch(
790 current_branch_only=True,
791 depth=1,
792 use_superproject=False,
793 )
794
795 self.assertTrue(res)
796 mock_git_cmd.assert_not_called()
diff --git a/tests/test_subcmds_forall.py b/tests/test_subcmds_forall.py
index e50b28d8a..e36c8b660 100644
--- a/tests/test_subcmds_forall.py
+++ b/tests/test_subcmds_forall.py
@@ -14,11 +14,9 @@
14 14
15"""Unittests for the forall subcmd.""" 15"""Unittests for the forall subcmd."""
16 16
17from io import StringIO 17import contextlib
18import os 18import io
19from shutil import rmtree 19from pathlib import Path
20import tempfile
21import unittest
22from unittest import mock 20from unittest import mock
23 21
24import utils_for_test 22import utils_for_test
@@ -28,111 +26,81 @@ import project
28import subcmds 26import subcmds
29 27
30 28
31class AllCommands(unittest.TestCase): 29def _create_manifest_with_8_projects(
32 """Check registered all_commands.""" 30 topdir: Path,
33 31) -> manifest_xml.XmlManifest:
34 def setUp(self): 32 """Create a setup of 8 projects to execute forall."""
35 """Common setup.""" 33 repodir = topdir / ".repo"
36 self.tempdirobj = tempfile.TemporaryDirectory(prefix="forall_tests") 34 manifest_dir = repodir / "manifests"
37 self.tempdir = self.tempdirobj.name 35 manifest_file = repodir / manifest_xml.MANIFEST_FILE_NAME
38 self.repodir = os.path.join(self.tempdir, ".repo") 36
39 self.manifest_dir = os.path.join(self.repodir, "manifests") 37 repodir.mkdir()
40 self.manifest_file = os.path.join( 38 manifest_dir.mkdir()
41 self.repodir, manifest_xml.MANIFEST_FILE_NAME 39
42 ) 40 # Set up a manifest git dir for parsing to work.
43 self.local_manifest_dir = os.path.join( 41 gitdir = repodir / "manifests.git"
44 self.repodir, manifest_xml.LOCAL_MANIFESTS_DIR_NAME 42 gitdir.mkdir()
45 ) 43 (gitdir / "config").write_text(
46 os.mkdir(self.repodir) 44 """[remote "origin"]
47 os.mkdir(self.manifest_dir) 45 url = https://localhost:0/manifest
48 46 verbose = false
49 def tearDown(self): 47 """
50 """Common teardown.""" 48 )
51 rmtree(self.tempdir, ignore_errors=True) 49
52 50 # Add the manifest data.
53 def getXmlManifestWith8Projects(self): 51 manifest_file.write_text(
54 """Create and return a setup of 8 projects with enough dummy 52 """
55 files and setup to execute forall.""" 53 <manifest>
56 54 <remote name="origin" fetch="http://localhost" />
57 # Set up a manifest git dir for parsing to work 55 <default remote="origin" revision="refs/heads/main" />
58 gitdir = os.path.join(self.repodir, "manifests.git") 56 <project name="project1" path="tests/path1" />
59 os.mkdir(gitdir) 57 <project name="project2" path="tests/path2" />
60 with open(os.path.join(gitdir, "config"), "w") as fp: 58 <project name="project3" path="tests/path3" />
61 fp.write( 59 <project name="project4" path="tests/path4" />
62 """[remote "origin"] 60 <project name="project5" path="tests/path5" />
63 url = https://localhost:0/manifest 61 <project name="project6" path="tests/path6" />
64 verbose = false 62 <project name="project7" path="tests/path7" />
65 """ 63 <project name="project8" path="tests/path8" />
66 ) 64 </manifest>
67 65 """,
68 # Add the manifest data 66 encoding="utf-8",
69 manifest_data = """ 67 )
70 <manifest> 68
71 <remote name="origin" fetch="http://localhost" /> 69 # Set up 8 empty projects to match the manifest.
72 <default remote="origin" revision="refs/heads/main" /> 70 for x in range(1, 9):
73 <project name="project1" path="tests/path1" /> 71 (repodir / "projects" / "tests" / f"path{x}.git").mkdir(parents=True)
74 <project name="project2" path="tests/path2" /> 72 (repodir / "project-objects" / f"project{x}.git").mkdir(parents=True)
75 <project name="project3" path="tests/path3" /> 73 git_path = topdir / "tests" / f"path{x}"
76 <project name="project4" path="tests/path4" /> 74 utils_for_test.init_git_tree(git_path)
77 <project name="project5" path="tests/path5" /> 75
78 <project name="project6" path="tests/path6" /> 76 return manifest_xml.XmlManifest(str(repodir), str(manifest_file))
79 <project name="project7" path="tests/path7" /> 77
80 <project name="project8" path="tests/path8" /> 78
81 </manifest> 79def test_forall_all_projects_called_once(tmp_path: Path) -> None:
82 """ 80 """Test that all projects get a command run once each."""
83 with open(self.manifest_file, "w", encoding="utf-8") as fp: 81 manifest = _create_manifest_with_8_projects(tmp_path)
84 fp.write(manifest_data) 82
85 83 cmd = subcmds.forall.Forall()
86 # Set up 8 empty projects to match the manifest 84 cmd.manifest = manifest
87 for x in range(1, 9): 85
88 os.makedirs( 86 # Use echo project names as the test of forall.
89 os.path.join( 87 opts, args = cmd.OptionParser.parse_args(["-c", "echo $REPO_PROJECT"])
90 self.repodir, "projects/tests/path" + str(x) + ".git" 88 opts.verbose = False
91 ) 89
92 ) 90 with contextlib.redirect_stdout(io.StringIO()) as stdout:
93 os.makedirs( 91 # Mock to not have the Execute fail on remote check.
94 os.path.join(
95 self.repodir, "project-objects/project" + str(x) + ".git"
96 )
97 )
98 git_path = os.path.join(self.tempdir, "tests/path" + str(x))
99 utils_for_test.init_git_tree(git_path)
100
101 return manifest_xml.XmlManifest(self.repodir, self.manifest_file)
102
103 # Use mock to capture stdout from the forall run
104 @unittest.mock.patch("sys.stdout", new_callable=StringIO)
105 def test_forall_all_projects_called_once(self, mock_stdout):
106 """Test that all projects get a command run once each."""
107
108 manifest_with_8_projects = self.getXmlManifestWith8Projects()
109
110 cmd = subcmds.forall.Forall()
111 cmd.manifest = manifest_with_8_projects
112
113 # Use echo project names as the test of forall
114 opts, args = cmd.OptionParser.parse_args(["-c", "echo $REPO_PROJECT"])
115 opts.verbose = False
116
117 # Mock to not have the Execute fail on remote check
118 with mock.patch.object( 92 with mock.patch.object(
119 project.Project, "GetRevisionId", return_value="refs/heads/main" 93 project.Project, "GetRevisionId", return_value="refs/heads/main"
120 ): 94 ):
121 # Run the forall command 95 # Run the forall command.
122 cmd.Execute(opts, args) 96 cmd.Execute(opts, args)
123 97
124 # Verify that we got every project name in the prints 98 output = stdout.getvalue()
125 for x in range(1, 9): 99 # Verify that we got every project name in the output.
126 self.assertIn("project" + str(x), mock_stdout.getvalue()) 100 for x in range(1, 9):
127 101 assert f"project{x}" in output
128 # Split the captured output into lines to count them 102
129 line_count = 0 103 # Split the captured output into lines to count them.
130 for line in mock_stdout.getvalue().split("\n"): 104 line_count = sum(1 for x in output.splitlines() if x)
131 # A commented out print to stderr as a reminder 105 # Verify that we didn't get more lines than expected.
132 # that stdout is mocked, include sys and uncomment if needed 106 assert line_count == 8
133 # print(line, file=sys.stderr)
134 if len(line) > 0:
135 line_count += 1
136
137 # Verify that we didn't get more lines than expected
138 assert line_count == 8
diff --git a/tests/test_subcmds_sync.py b/tests/test_subcmds_sync.py
index 9fef68425..b3726d2d1 100644
--- a/tests/test_subcmds_sync.py
+++ b/tests/test_subcmds_sync.py
@@ -477,6 +477,115 @@ class GetPreciousObjectsState(unittest.TestCase):
477 ) 477 )
478 478
479 479
480class CheckForBloatedProjects(unittest.TestCase):
481 """Tests for Sync._CheckForBloatedProjects."""
482
483 def setUp(self):
484 self.cmd = sync.Sync()
485 self.opt = mock.Mock()
486 self.opt.quiet = True
487 self.opt.jobs = 1
488 self.project = mock.MagicMock(clone_depth="1")
489 self.project.name = "project"
490 self.project.Exists = True
491 self.project.worktree = "worktree"
492 self.cmd.git_event_log = mock.MagicMock()
493 self.cmd._bloated_projects = []
494
495 @mock.patch("subcmds.sync.git_require")
496 def test_git_version_unsupported(self, mock_git_require):
497 """Test that it returns early if git version is unsupported."""
498 mock_git_require.return_value = False
499 self.cmd._CheckForBloatedProjects([self.project], self.opt)
500 self.assertFalse(self.cmd.git_event_log.ErrorEvent.called)
501
502 @mock.patch("subcmds.sync.git_require")
503 def test_no_projects(self, mock_git_require):
504 """Test that it returns early if no projects have clone_depth."""
505 mock_git_require.return_value = True
506 self.project.clone_depth = None
507 self.cmd._CheckForBloatedProjects([self.project], self.opt)
508 self.assertFalse(self.cmd.git_event_log.ErrorEvent.called)
509
510 @mock.patch("subcmds.sync.git_require")
511 @mock.patch("subcmds.sync.Progress")
512 def test_bloated_project_found(self, mock_progress, mock_git_require):
513 """Test that it adds project to _bloated_projects."""
514 mock_git_require.return_value = True
515
516 self.cmd.get_parallel_context = mock.Mock(
517 return_value={"projects": [self.project]}
518 )
519
520 def mock_execute_in_parallel(
521 jobs, func, work_items, callback, **kwargs
522 ):
523 callback(None, mock.Mock(), ["project"])
524 return True
525
526 self.cmd.ExecuteInParallel = mock_execute_in_parallel
527
528 with mock.patch.object(self.cmd, "ParallelContext"):
529 self.cmd._CheckForBloatedProjects([self.project], self.opt)
530
531 self.assertEqual(self.cmd._bloated_projects, ["project"])
532
533
534class GCProjectsTest(unittest.TestCase):
535 """Tests for Sync._GCProjects."""
536
537 def setUp(self):
538 self.cmd = sync.Sync()
539 self.opt = mock.Mock()
540 self.opt.quiet = True
541 self.opt.auto_gc = True
542 self.opt.jobs = 1
543 self.project = mock.MagicMock()
544 self.project.name = "project"
545 self.project.objdir = "objdir"
546 self.project.gitdir = "gitdir"
547 self.project.bare_git = mock.MagicMock()
548 self.project.bare_git._project = self.project
549 self.cmd.git_event_log = mock.MagicMock()
550
551 @mock.patch("subcmds.sync.Progress")
552 def test_GCProjects_skip_gc(self, mock_progress):
553 """Test that it skips GC if opt.auto_gc is False."""
554 self.opt.auto_gc = False
555 with mock.patch.object(
556 sync.Sync, "_SetPreciousObjectsState"
557 ) as mock_set_state:
558 self.cmd._GCProjects([self.project], self.opt, None)
559 mock_set_state.assert_called_once_with(self.project, self.opt)
560 self.assertFalse(self.project.bare_git.gc.called)
561
562 @mock.patch("subcmds.sync.Progress")
563 def test_GCProjects_sequential(self, mock_progress):
564 """Test sequential GC (jobs < 2)."""
565 with mock.patch.object(sync.Sync, "_SetPreciousObjectsState"):
566 self.cmd._GCProjects([self.project], self.opt, None)
567 self.project.bare_git.gc.assert_called_once_with(
568 "--auto", config={"gc.autoDetach": "false"}
569 )
570 # Verify that gc.autoDetach was not permanently set in config.
571 for call in self.project.config.SetString.call_args_list:
572 self.assertNotEqual(call.args[0], "gc.autoDetach")
573
574 @mock.patch("subcmds.sync.Progress")
575 def test_GCProjects_parallel(self, mock_progress):
576 """Test parallel GC (jobs >= 2)."""
577 self.opt.jobs = 2
578 with mock.patch.object(sync.Sync, "_SetPreciousObjectsState"):
579 with mock.patch("subcmds.sync._threading.Thread") as mock_thread:
580 mock_t = mock.MagicMock()
581 mock_thread.return_value = mock_t
582 err_event = mock.Mock()
583 err_event.is_set.return_value = False
584 self.cmd._GCProjects([self.project], self.opt, err_event)
585
586 self.assertTrue(mock_thread.called)
587
588
480class SyncCommand(unittest.TestCase): 589class SyncCommand(unittest.TestCase):
481 """Tests for cmd.Execute.""" 590 """Tests for cmd.Execute."""
482 591
diff --git a/tests/test_subcmds_upload.py b/tests/test_subcmds_upload.py
index cd8889778..51c0a4cb7 100644
--- a/tests/test_subcmds_upload.py
+++ b/tests/test_subcmds_upload.py
@@ -14,9 +14,10 @@
14 14
15"""Unittests for the subcmds/upload.py module.""" 15"""Unittests for the subcmds/upload.py module."""
16 16
17import unittest
18from unittest import mock 17from unittest import mock
19 18
19import pytest
20
20from error import GitError 21from error import GitError
21from error import UploadError 22from error import UploadError
22from subcmds import upload 23from subcmds import upload
@@ -26,45 +27,39 @@ class UnexpectedError(Exception):
26 """An exception not expected by upload command.""" 27 """An exception not expected by upload command."""
27 28
28 29
29class UploadCommand(unittest.TestCase): 30# A stub people list (reviewers, cc).
30 """Check registered all_commands.""" 31_STUB_PEOPLE = ([], [])
31 32
32 def setUp(self): 33
33 self.cmd = upload.Upload() 34@pytest.fixture
34 self.branch = mock.MagicMock() 35def cmd() -> upload.Upload:
35 self.people = mock.MagicMock() 36 """Fixture to provide an Upload command instance with mocked methods."""
36 self.opt, _ = self.cmd.OptionParser.parse_args([]) 37 cmd = upload.Upload()
37 mock.patch.object( 38 with mock.patch.object(
38 self.cmd, "_AppendAutoList", return_value=None 39 cmd, "_AppendAutoList", return_value=None
39 ).start() 40 ), mock.patch.object(cmd, "git_event_log"):
40 mock.patch.object(self.cmd, "git_event_log").start() 41 yield cmd
41 42
42 def tearDown(self): 43
43 mock.patch.stopall() 44def test_UploadAndReport_UploadError(cmd: upload.Upload) -> None:
44 45 """Check UploadExitError raised when UploadError encountered."""
45 def test_UploadAndReport_UploadError(self): 46 opt, _ = cmd.OptionParser.parse_args([])
46 """Check UploadExitError raised when UploadError encountered.""" 47 with mock.patch.object(cmd, "_UploadBranch", side_effect=UploadError("")):
47 side_effect = UploadError("upload error") 48 with pytest.raises(upload.UploadExitError):
48 with mock.patch.object( 49 cmd._UploadAndReport(opt, [mock.MagicMock()], _STUB_PEOPLE)
49 self.cmd, "_UploadBranch", side_effect=side_effect 50
50 ): 51
51 with self.assertRaises(upload.UploadExitError): 52def test_UploadAndReport_GitError(cmd: upload.Upload) -> None:
52 self.cmd._UploadAndReport(self.opt, [self.branch], self.people) 53 """Check UploadExitError raised when GitError encountered."""
53 54 opt, _ = cmd.OptionParser.parse_args([])
54 def test_UploadAndReport_GitError(self): 55 with mock.patch.object(cmd, "_UploadBranch", side_effect=GitError("")):
55 """Check UploadExitError raised when GitError encountered.""" 56 with pytest.raises(upload.UploadExitError):
56 side_effect = GitError("some git error") 57 cmd._UploadAndReport(opt, [mock.MagicMock()], _STUB_PEOPLE)
57 with mock.patch.object( 58
58 self.cmd, "_UploadBranch", side_effect=side_effect 59
59 ): 60def test_UploadAndReport_UnhandledError(cmd: upload.Upload) -> None:
60 with self.assertRaises(upload.UploadExitError): 61 """Check UnexpectedError passed through."""
61 self.cmd._UploadAndReport(self.opt, [self.branch], self.people) 62 opt, _ = cmd.OptionParser.parse_args([])
62 63 with mock.patch.object(cmd, "_UploadBranch", side_effect=UnexpectedError):
63 def test_UploadAndReport_UnhandledError(self): 64 with pytest.raises(UnexpectedError):
64 """Check UnexpectedError passed through.""" 65 cmd._UploadAndReport(opt, [mock.MagicMock()], _STUB_PEOPLE)
65 side_effect = UnexpectedError("some os error")
66 with mock.patch.object(
67 self.cmd, "_UploadBranch", side_effect=side_effect
68 ):
69 with self.assertRaises(type(side_effect)):
70 self.cmd._UploadAndReport(self.opt, [self.branch], self.people)
diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py
index a38705675..12c30c786 100644
--- a/tests/test_wrapper.py
+++ b/tests/test_wrapper.py
@@ -19,267 +19,303 @@ import os
19import re 19import re
20import subprocess 20import subprocess
21import sys 21import sys
22import tempfile
23import unittest
24from unittest import mock 22from unittest import mock
25 23
24import pytest
26import utils_for_test 25import utils_for_test
27 26
28import main 27import main
29import wrapper 28import wrapper
30 29
31 30
32def fixture(*paths): 31@pytest.fixture(autouse=True)
33 """Return a path relative to tests/fixtures.""" 32def reset_wrapper() -> None:
34 return os.path.join(os.path.dirname(__file__), "fixtures", *paths) 33 """Reset the wrapper module every time."""
34 wrapper.Wrapper.cache_clear()
35 35
36 36
37class RepoWrapperTestCase(unittest.TestCase): 37@pytest.fixture
38 """TestCase for the wrapper module.""" 38def repo_wrapper() -> wrapper.Wrapper:
39 """Fixture for the wrapper module."""
40 return wrapper.Wrapper()
39 41
40 def setUp(self):
41 """Load the wrapper module every time."""
42 wrapper.Wrapper.cache_clear()
43 self.wrapper = wrapper.Wrapper()
44 42
43class GitCheckout:
44 """Class to hold git checkout info for tests."""
45 45
46class RepoWrapperUnitTest(RepoWrapperTestCase): 46 def __init__(self, git_dir, rev_list):
47 self.git_dir = git_dir
48 self.rev_list = rev_list
49
50
51@pytest.fixture(scope="module")
52def git_checkout(tmp_path_factory) -> GitCheckout:
53 """Fixture for tests that use a real/small git checkout.
54
55 Create a repo to operate on, but do it once per-test-run.
56 """
57 tempdir = tmp_path_factory.mktemp("repo-rev-tests")
58 run_git = wrapper.Wrapper().run_git
59
60 remote = os.path.join(tempdir, "remote")
61 os.mkdir(remote)
62
63 utils_for_test.init_git_tree(remote)
64 run_git("commit", "--allow-empty", "-minit", cwd=remote)
65 run_git("branch", "stable", cwd=remote)
66 run_git("tag", "v1.0", cwd=remote)
67 run_git("commit", "--allow-empty", "-m2nd commit", cwd=remote)
68 rev_list = run_git("rev-list", "HEAD", cwd=remote).stdout.splitlines()
69
70 run_git("init", cwd=tempdir)
71 run_git(
72 "fetch",
73 remote,
74 "+refs/heads/*:refs/remotes/origin/*",
75 cwd=tempdir,
76 )
77 yield GitCheckout(tempdir, rev_list)
78
79
80class TestRepoWrapper:
47 """Tests helper functions in the repo wrapper""" 81 """Tests helper functions in the repo wrapper"""
48 82
49 def test_version(self): 83 def test_version(self, repo_wrapper: wrapper.Wrapper) -> None:
50 """Make sure _Version works.""" 84 """Make sure _Version works."""
51 with self.assertRaises(SystemExit) as e: 85 with pytest.raises(SystemExit) as e:
52 with mock.patch("sys.stdout", new_callable=io.StringIO) as stdout: 86 with mock.patch("sys.stdout", new_callable=io.StringIO) as stdout:
53 with mock.patch( 87 with mock.patch(
54 "sys.stderr", new_callable=io.StringIO 88 "sys.stderr", new_callable=io.StringIO
55 ) as stderr: 89 ) as stderr:
56 self.wrapper._Version() 90 repo_wrapper._Version()
57 self.assertEqual(0, e.exception.code) 91 assert e.value.code == 0
58 self.assertEqual("", stderr.getvalue()) 92 assert stderr.getvalue() == ""
59 self.assertIn("repo launcher version", stdout.getvalue()) 93 assert "repo launcher version" in stdout.getvalue()
60 94
61 def test_python_constraints(self): 95 def test_python_constraints(self, repo_wrapper: wrapper.Wrapper) -> None:
62 """The launcher should never require newer than main.py.""" 96 """The launcher should never require newer than main.py."""
63 self.assertGreaterEqual( 97 assert (
64 main.MIN_PYTHON_VERSION_HARD, self.wrapper.MIN_PYTHON_VERSION_HARD 98 main.MIN_PYTHON_VERSION_HARD >= repo_wrapper.MIN_PYTHON_VERSION_HARD
65 ) 99 )
66 self.assertGreaterEqual( 100 assert (
67 main.MIN_PYTHON_VERSION_SOFT, self.wrapper.MIN_PYTHON_VERSION_SOFT 101 main.MIN_PYTHON_VERSION_SOFT >= repo_wrapper.MIN_PYTHON_VERSION_SOFT
68 ) 102 )
69 # Make sure the versions are themselves in sync. 103 # Make sure the versions are themselves in sync.
70 self.assertGreaterEqual( 104 assert (
71 self.wrapper.MIN_PYTHON_VERSION_SOFT, 105 repo_wrapper.MIN_PYTHON_VERSION_SOFT
72 self.wrapper.MIN_PYTHON_VERSION_HARD, 106 >= repo_wrapper.MIN_PYTHON_VERSION_HARD
73 ) 107 )
74 108
75 def test_init_parser(self): 109 def test_repo_script_is_executable(self) -> None:
110 """The repo launcher script should be executable."""
111 repo_path = utils_for_test.THIS_DIR.parent / "repo"
112 assert os.access(repo_path, os.X_OK), f"{repo_path} is not executable"
113
114 def test_init_parser(self, repo_wrapper: wrapper.Wrapper) -> None:
76 """Make sure 'init' GetParser works.""" 115 """Make sure 'init' GetParser works."""
77 parser = self.wrapper.GetParser() 116 parser = repo_wrapper.GetParser()
78 opts, args = parser.parse_args([]) 117 opts, args = parser.parse_args([])
79 self.assertEqual([], args) 118 assert args == []
80 self.assertIsNone(opts.manifest_url) 119 assert opts.manifest_url is None
81 120
82 121
83class SetGitTrace2ParentSid(RepoWrapperTestCase): 122class TestSetGitTrace2ParentSid:
84 """Check SetGitTrace2ParentSid behavior.""" 123 """Check SetGitTrace2ParentSid behavior."""
85 124
86 KEY = "GIT_TRACE2_PARENT_SID" 125 KEY = "GIT_TRACE2_PARENT_SID"
87 VALID_FORMAT = re.compile(r"^repo-[0-9]{8}T[0-9]{6}Z-P[0-9a-f]{8}$") 126 VALID_FORMAT = re.compile(r"^repo-[0-9]{8}T[0-9]{6}Z-P[0-9a-f]{8}$")
88 127
89 def test_first_set(self): 128 def test_first_set(self, repo_wrapper: wrapper.Wrapper) -> None:
90 """Test env var not yet set.""" 129 """Test env var not yet set."""
91 env = {} 130 env = {}
92 self.wrapper.SetGitTrace2ParentSid(env) 131 repo_wrapper.SetGitTrace2ParentSid(env)
93 self.assertIn(self.KEY, env) 132 assert self.KEY in env
94 value = env[self.KEY] 133 value = env[self.KEY]
95 self.assertRegex(value, self.VALID_FORMAT) 134 assert self.VALID_FORMAT.match(value)
96 135
97 def test_append(self): 136 def test_append(self, repo_wrapper: wrapper.Wrapper) -> None:
98 """Test env var is appended.""" 137 """Test env var is appended."""
99 env = {self.KEY: "pfx"} 138 env = {self.KEY: "pfx"}
100 self.wrapper.SetGitTrace2ParentSid(env) 139 repo_wrapper.SetGitTrace2ParentSid(env)
101 self.assertIn(self.KEY, env) 140 assert self.KEY in env
102 value = env[self.KEY] 141 value = env[self.KEY]
103 self.assertTrue(value.startswith("pfx/")) 142 assert value.startswith("pfx/")
104 self.assertRegex(value[4:], self.VALID_FORMAT) 143 assert self.VALID_FORMAT.match(value[4:])
105 144
106 def test_global_context(self): 145 def test_global_context(self, repo_wrapper: wrapper.Wrapper) -> None:
107 """Check os.environ gets updated by default.""" 146 """Check os.environ gets updated by default."""
108 os.environ.pop(self.KEY, None) 147 os.environ.pop(self.KEY, None)
109 self.wrapper.SetGitTrace2ParentSid() 148 repo_wrapper.SetGitTrace2ParentSid()
110 self.assertIn(self.KEY, os.environ) 149 assert self.KEY in os.environ
111 value = os.environ[self.KEY] 150 value = os.environ[self.KEY]
112 self.assertRegex(value, self.VALID_FORMAT) 151 assert self.VALID_FORMAT.match(value)
113 152
114 153
115class RunCommand(RepoWrapperTestCase): 154class TestRunCommand:
116 """Check run_command behavior.""" 155 """Check run_command behavior."""
117 156
118 def test_capture(self): 157 def test_capture(self, repo_wrapper: wrapper.Wrapper) -> None:
119 """Check capture_output handling.""" 158 """Check capture_output handling."""
120 ret = self.wrapper.run_command(["echo", "hi"], capture_output=True) 159 ret = repo_wrapper.run_command(["echo", "hi"], capture_output=True)
121 # echo command appends OS specific linesep, but on Windows + Git Bash 160 # echo command appends OS specific linesep, but on Windows + Git Bash
122 # we get UNIX ending, so we allow both. 161 # we get UNIX ending, so we allow both.
123 self.assertIn(ret.stdout, ["hi" + os.linesep, "hi\n"]) 162 assert ret.stdout in ["hi" + os.linesep, "hi\n"]
124 163
125 def test_check(self): 164 def test_check(self, repo_wrapper: wrapper.Wrapper) -> None:
126 """Check check handling.""" 165 """Check check handling."""
127 self.wrapper.run_command(["true"], check=False) 166 repo_wrapper.run_command(["true"], check=False)
128 self.wrapper.run_command(["true"], check=True) 167 repo_wrapper.run_command(["true"], check=True)
129 self.wrapper.run_command(["false"], check=False) 168 repo_wrapper.run_command(["false"], check=False)
130 with self.assertRaises(subprocess.CalledProcessError): 169 with pytest.raises(subprocess.CalledProcessError):
131 self.wrapper.run_command(["false"], check=True) 170 repo_wrapper.run_command(["false"], check=True)
132 171
133 172
134class RunGit(RepoWrapperTestCase): 173class TestRunGit:
135 """Check run_git behavior.""" 174 """Check run_git behavior."""
136 175
137 def test_capture(self): 176 def test_capture(self, repo_wrapper: wrapper.Wrapper) -> None:
138 """Check capture_output handling.""" 177 """Check capture_output handling."""
139 ret = self.wrapper.run_git("--version") 178 ret = repo_wrapper.run_git("--version")
140 self.assertIn("git", ret.stdout) 179 assert "git" in ret.stdout
141 180
142 def test_check(self): 181 def test_check(self, repo_wrapper: wrapper.Wrapper) -> None:
143 """Check check handling.""" 182 """Check check handling."""
144 with self.assertRaises(self.wrapper.CloneFailure): 183 with pytest.raises(repo_wrapper.CloneFailure):
145 self.wrapper.run_git("--version-asdfasdf") 184 repo_wrapper.run_git("--version-asdfasdf")
146 self.wrapper.run_git("--version-asdfasdf", check=False) 185 repo_wrapper.run_git("--version-asdfasdf", check=False)
147 186
148 187
149class ParseGitVersion(RepoWrapperTestCase): 188class TestParseGitVersion:
150 """Check ParseGitVersion behavior.""" 189 """Check ParseGitVersion behavior."""
151 190
152 def test_autoload(self): 191 def test_autoload(self, repo_wrapper: wrapper.Wrapper) -> None:
153 """Check we can load the version from the live git.""" 192 """Check we can load the version from the live git."""
154 ret = self.wrapper.ParseGitVersion() 193 assert repo_wrapper.ParseGitVersion() is not None
155 self.assertIsNotNone(ret)
156 194
157 def test_bad_ver(self): 195 def test_bad_ver(self, repo_wrapper: wrapper.Wrapper) -> None:
158 """Check handling of bad git versions.""" 196 """Check handling of bad git versions."""
159 ret = self.wrapper.ParseGitVersion(ver_str="asdf") 197 assert repo_wrapper.ParseGitVersion(ver_str="asdf") is None
160 self.assertIsNone(ret)
161 198
162 def test_normal_ver(self): 199 def test_normal_ver(self, repo_wrapper: wrapper.Wrapper) -> None:
163 """Check handling of normal git versions.""" 200 """Check handling of normal git versions."""
164 ret = self.wrapper.ParseGitVersion(ver_str="git version 2.25.1") 201 ret = repo_wrapper.ParseGitVersion(ver_str="git version 2.25.1")
165 self.assertEqual(2, ret.major) 202 assert ret.major == 2
166 self.assertEqual(25, ret.minor) 203 assert ret.minor == 25
167 self.assertEqual(1, ret.micro) 204 assert ret.micro == 1
168 self.assertEqual("2.25.1", ret.full) 205 assert ret.full == "2.25.1"
169 206
170 def test_extended_ver(self): 207 def test_extended_ver(self, repo_wrapper: wrapper.Wrapper) -> None:
171 """Check handling of extended distro git versions.""" 208 """Check handling of extended distro git versions."""
172 ret = self.wrapper.ParseGitVersion( 209 ret = repo_wrapper.ParseGitVersion(
173 ver_str="git version 1.30.50.696.g5e7596f4ac-goog" 210 ver_str="git version 1.30.50.696.g5e7596f4ac-goog"
174 ) 211 )
175 self.assertEqual(1, ret.major) 212 assert ret.major == 1
176 self.assertEqual(30, ret.minor) 213 assert ret.minor == 30
177 self.assertEqual(50, ret.micro) 214 assert ret.micro == 50
178 self.assertEqual("1.30.50.696.g5e7596f4ac-goog", ret.full) 215 assert ret.full == "1.30.50.696.g5e7596f4ac-goog"
179 216
180 217
181class CheckGitVersion(RepoWrapperTestCase): 218class TestCheckGitVersion:
182 """Check _CheckGitVersion behavior.""" 219 """Check _CheckGitVersion behavior."""
183 220
184 def test_unknown(self): 221 def test_unknown(self, repo_wrapper: wrapper.Wrapper) -> None:
185 """Unknown versions should abort.""" 222 """Unknown versions should abort."""
186 with mock.patch.object( 223 with mock.patch.object(
187 self.wrapper, "ParseGitVersion", return_value=None 224 repo_wrapper, "ParseGitVersion", return_value=None
188 ): 225 ):
189 with self.assertRaises(self.wrapper.CloneFailure): 226 with pytest.raises(repo_wrapper.CloneFailure):
190 self.wrapper._CheckGitVersion() 227 repo_wrapper._CheckGitVersion()
191 228
192 def test_old(self): 229 def test_old(self, repo_wrapper: wrapper.Wrapper) -> None:
193 """Old versions should abort.""" 230 """Old versions should abort."""
194 with mock.patch.object( 231 with mock.patch.object(
195 self.wrapper, 232 repo_wrapper,
196 "ParseGitVersion", 233 "ParseGitVersion",
197 return_value=self.wrapper.GitVersion(1, 0, 0, "1.0.0"), 234 return_value=repo_wrapper.GitVersion(1, 0, 0, "1.0.0"),
198 ): 235 ):
199 with self.assertRaises(self.wrapper.CloneFailure): 236 with pytest.raises(repo_wrapper.CloneFailure):
200 self.wrapper._CheckGitVersion() 237 repo_wrapper._CheckGitVersion()
201 238
202 def test_new(self): 239 def test_new(self, repo_wrapper: wrapper.Wrapper) -> None:
203 """Newer versions should run fine.""" 240 """Newer versions should run fine."""
204 with mock.patch.object( 241 with mock.patch.object(
205 self.wrapper, 242 repo_wrapper,
206 "ParseGitVersion", 243 "ParseGitVersion",
207 return_value=self.wrapper.GitVersion(100, 0, 0, "100.0.0"), 244 return_value=repo_wrapper.GitVersion(100, 0, 0, "100.0.0"),
208 ): 245 ):
209 self.wrapper._CheckGitVersion() 246 repo_wrapper._CheckGitVersion()
210 247
211 248
212class Requirements(RepoWrapperTestCase): 249class TestRequirements:
213 """Check Requirements handling.""" 250 """Check Requirements handling."""
214 251
215 def test_missing_file(self): 252 def test_missing_file(self, repo_wrapper: wrapper.Wrapper) -> None:
216 """Don't crash if the file is missing (old version).""" 253 """Don't crash if the file is missing (old version)."""
217 testdir = os.path.dirname(os.path.realpath(__file__)) 254 assert (
218 self.assertIsNone(self.wrapper.Requirements.from_dir(testdir)) 255 repo_wrapper.Requirements.from_dir(utils_for_test.THIS_DIR) is None
219 self.assertIsNone( 256 )
220 self.wrapper.Requirements.from_file( 257 assert (
221 os.path.join(testdir, "xxxxxxxxxxxxxxxxxxxxxxxx") 258 repo_wrapper.Requirements.from_file(
259 utils_for_test.THIS_DIR / "xxxxxxxxxxxxxxxxxxxxxxxx"
222 ) 260 )
261 is None
223 ) 262 )
224 263
225 def test_corrupt_data(self): 264 def test_corrupt_data(self, repo_wrapper: wrapper.Wrapper) -> None:
226 """If the file can't be parsed, don't blow up.""" 265 """If the file can't be parsed, don't blow up."""
227 self.assertIsNone(self.wrapper.Requirements.from_file(__file__)) 266 assert repo_wrapper.Requirements.from_file(__file__) is None
228 self.assertIsNone(self.wrapper.Requirements.from_data(b"x")) 267 assert repo_wrapper.Requirements.from_data(b"x") is None
229 268
230 def test_valid_data(self): 269 def test_valid_data(self, repo_wrapper: wrapper.Wrapper) -> None:
231 """Make sure we can parse the file we ship.""" 270 """Make sure we can parse the file we ship."""
232 self.assertIsNotNone(self.wrapper.Requirements.from_data(b"{}")) 271 assert repo_wrapper.Requirements.from_data(b"{}") is not None
233 rootdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 272 rootdir = utils_for_test.THIS_DIR.parent
234 self.assertIsNotNone(self.wrapper.Requirements.from_dir(rootdir)) 273 assert repo_wrapper.Requirements.from_dir(rootdir) is not None
235 self.assertIsNotNone( 274 assert (
236 self.wrapper.Requirements.from_file( 275 repo_wrapper.Requirements.from_file(rootdir / "requirements.json")
237 os.path.join(rootdir, "requirements.json") 276 is not None
238 )
239 ) 277 )
240 278
241 def test_format_ver(self): 279 def test_format_ver(self, repo_wrapper: wrapper.Wrapper) -> None:
242 """Check format_ver can format.""" 280 """Check format_ver can format."""
243 self.assertEqual( 281 assert repo_wrapper.Requirements._format_ver((1, 2, 3)) == "1.2.3"
244 "1.2.3", self.wrapper.Requirements._format_ver((1, 2, 3)) 282 assert repo_wrapper.Requirements._format_ver([1]) == "1"
245 )
246 self.assertEqual("1", self.wrapper.Requirements._format_ver([1]))
247 283
248 def test_assert_all_unknown(self): 284 def test_assert_all_unknown(self, repo_wrapper: wrapper.Wrapper) -> None:
249 """Check assert_all works with incompatible file.""" 285 """Check assert_all works with incompatible file."""
250 reqs = self.wrapper.Requirements({}) 286 reqs = repo_wrapper.Requirements({})
251 reqs.assert_all() 287 reqs.assert_all()
252 288
253 def test_assert_all_new_repo(self): 289 def test_assert_all_new_repo(self, repo_wrapper: wrapper.Wrapper) -> None:
254 """Check assert_all accepts new enough repo.""" 290 """Check assert_all accepts new enough repo."""
255 reqs = self.wrapper.Requirements({"repo": {"hard": [1, 0]}}) 291 reqs = repo_wrapper.Requirements({"repo": {"hard": [1, 0]}})
256 reqs.assert_all() 292 reqs.assert_all()
257 293
258 def test_assert_all_old_repo(self): 294 def test_assert_all_old_repo(self, repo_wrapper: wrapper.Wrapper) -> None:
259 """Check assert_all rejects old repo.""" 295 """Check assert_all rejects old repo."""
260 reqs = self.wrapper.Requirements({"repo": {"hard": [99999, 0]}}) 296 reqs = repo_wrapper.Requirements({"repo": {"hard": [99999, 0]}})
261 with self.assertRaises(SystemExit): 297 with pytest.raises(SystemExit):
262 reqs.assert_all() 298 reqs.assert_all()
263 299
264 def test_assert_all_new_python(self): 300 def test_assert_all_new_python(self, repo_wrapper: wrapper.Wrapper) -> None:
265 """Check assert_all accepts new enough python.""" 301 """Check assert_all accepts new enough python."""
266 reqs = self.wrapper.Requirements({"python": {"hard": sys.version_info}}) 302 reqs = repo_wrapper.Requirements({"python": {"hard": sys.version_info}})
267 reqs.assert_all() 303 reqs.assert_all()
268 304
269 def test_assert_all_old_python(self): 305 def test_assert_all_old_python(self, repo_wrapper: wrapper.Wrapper) -> None:
270 """Check assert_all rejects old python.""" 306 """Check assert_all rejects old python."""
271 reqs = self.wrapper.Requirements({"python": {"hard": [99999, 0]}}) 307 reqs = repo_wrapper.Requirements({"python": {"hard": [99999, 0]}})
272 with self.assertRaises(SystemExit): 308 with pytest.raises(SystemExit):
273 reqs.assert_all() 309 reqs.assert_all()
274 310
275 def test_assert_ver_unknown(self): 311 def test_assert_ver_unknown(self, repo_wrapper: wrapper.Wrapper) -> None:
276 """Check assert_ver works with incompatible file.""" 312 """Check assert_ver works with incompatible file."""
277 reqs = self.wrapper.Requirements({}) 313 reqs = repo_wrapper.Requirements({})
278 reqs.assert_ver("xxx", (1, 0)) 314 reqs.assert_ver("xxx", (1, 0))
279 315
280 def test_assert_ver_new(self): 316 def test_assert_ver_new(self, repo_wrapper: wrapper.Wrapper) -> None:
281 """Check assert_ver allows new enough versions.""" 317 """Check assert_ver allows new enough versions."""
282 reqs = self.wrapper.Requirements( 318 reqs = repo_wrapper.Requirements(
283 {"git": {"hard": [1, 0], "soft": [2, 0]}} 319 {"git": {"hard": [1, 0], "soft": [2, 0]}}
284 ) 320 )
285 reqs.assert_ver("git", (1, 0)) 321 reqs.assert_ver("git", (1, 0))
@@ -287,274 +323,279 @@ class Requirements(RepoWrapperTestCase):
287 reqs.assert_ver("git", (2, 0)) 323 reqs.assert_ver("git", (2, 0))
288 reqs.assert_ver("git", (2, 5)) 324 reqs.assert_ver("git", (2, 5))
289 325
290 def test_assert_ver_old(self): 326 def test_assert_ver_old(self, repo_wrapper: wrapper.Wrapper) -> None:
291 """Check assert_ver rejects old versions.""" 327 """Check assert_ver rejects old versions."""
292 reqs = self.wrapper.Requirements( 328 reqs = repo_wrapper.Requirements(
293 {"git": {"hard": [1, 0], "soft": [2, 0]}} 329 {"git": {"hard": [1, 0], "soft": [2, 0]}}
294 ) 330 )
295 with self.assertRaises(SystemExit): 331 with pytest.raises(SystemExit):
296 reqs.assert_ver("git", (0, 5)) 332 reqs.assert_ver("git", (0, 5))
297 333
298 334
299class NeedSetupGnuPG(RepoWrapperTestCase): 335class TestNeedSetupGnuPG:
300 """Check NeedSetupGnuPG behavior.""" 336 """Check NeedSetupGnuPG behavior."""
301 337
302 def test_missing_dir(self): 338 def test_missing_dir(self, tmp_path, repo_wrapper: wrapper.Wrapper) -> None:
303 """The ~/.repoconfig tree doesn't exist yet.""" 339 """The ~/.repoconfig tree doesn't exist yet."""
304 with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: 340 repo_wrapper.home_dot_repo = str(tmp_path / "foo")
305 self.wrapper.home_dot_repo = os.path.join(tempdir, "foo") 341 assert repo_wrapper.NeedSetupGnuPG()
306 self.assertTrue(self.wrapper.NeedSetupGnuPG())
307 342
308 def test_missing_keyring(self): 343 def test_missing_keyring(
344 self, tmp_path, repo_wrapper: wrapper.Wrapper
345 ) -> None:
309 """The keyring-version file doesn't exist yet.""" 346 """The keyring-version file doesn't exist yet."""
310 with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: 347 repo_wrapper.home_dot_repo = str(tmp_path)
311 self.wrapper.home_dot_repo = tempdir 348 assert repo_wrapper.NeedSetupGnuPG()
312 self.assertTrue(self.wrapper.NeedSetupGnuPG())
313 349
314 def test_empty_keyring(self): 350 def test_empty_keyring(
351 self, tmp_path, repo_wrapper: wrapper.Wrapper
352 ) -> None:
315 """The keyring-version file exists, but is empty.""" 353 """The keyring-version file exists, but is empty."""
316 with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: 354 repo_wrapper.home_dot_repo = str(tmp_path)
317 self.wrapper.home_dot_repo = tempdir 355 (tmp_path / "keyring-version").write_text("")
318 with open(os.path.join(tempdir, "keyring-version"), "w"): 356 assert repo_wrapper.NeedSetupGnuPG()
319 pass
320 self.assertTrue(self.wrapper.NeedSetupGnuPG())
321 357
322 def test_old_keyring(self): 358 def test_old_keyring(self, tmp_path, repo_wrapper: wrapper.Wrapper) -> None:
323 """The keyring-version file exists, but it's old.""" 359 """The keyring-version file exists, but it's old."""
324 with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: 360 repo_wrapper.home_dot_repo = str(tmp_path)
325 self.wrapper.home_dot_repo = tempdir 361 (tmp_path / "keyring-version").write_text("1.0\n")
326 with open(os.path.join(tempdir, "keyring-version"), "w") as fp: 362 assert repo_wrapper.NeedSetupGnuPG()
327 fp.write("1.0\n")
328 self.assertTrue(self.wrapper.NeedSetupGnuPG())
329 363
330 def test_new_keyring(self): 364 def test_new_keyring(self, tmp_path, repo_wrapper: wrapper.Wrapper) -> None:
331 """The keyring-version file exists, and is up-to-date.""" 365 """The keyring-version file exists, and is up-to-date."""
332 with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: 366 repo_wrapper.home_dot_repo = str(tmp_path)
333 self.wrapper.home_dot_repo = tempdir 367 (tmp_path / "keyring-version").write_text("1000.0\n")
334 with open(os.path.join(tempdir, "keyring-version"), "w") as fp: 368 assert not repo_wrapper.NeedSetupGnuPG()
335 fp.write("1000.0\n")
336 self.assertFalse(self.wrapper.NeedSetupGnuPG())
337 369
338 370
339class SetupGnuPG(RepoWrapperTestCase): 371class TestSetupGnuPG:
340 """Check SetupGnuPG behavior.""" 372 """Check SetupGnuPG behavior."""
341 373
342 def test_full(self): 374 def test_full(self, tmp_path, repo_wrapper: wrapper.Wrapper) -> None:
343 """Make sure it works completely.""" 375 """Make sure it works completely."""
344 with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: 376 repo_wrapper.home_dot_repo = str(tmp_path)
345 self.wrapper.home_dot_repo = tempdir 377 repo_wrapper.gpg_dir = str(tmp_path / "gnupg")
346 self.wrapper.gpg_dir = os.path.join( 378 assert repo_wrapper.SetupGnuPG(True)
347 self.wrapper.home_dot_repo, "gnupg" 379 data = (tmp_path / "keyring-version").read_text()
348 ) 380 assert (
349 self.assertTrue(self.wrapper.SetupGnuPG(True)) 381 ".".join(str(x) for x in repo_wrapper.KEYRING_VERSION)
350 with open(os.path.join(tempdir, "keyring-version")) as fp: 382 == data.strip()
351 data = fp.read() 383 )
352 self.assertEqual(
353 ".".join(str(x) for x in self.wrapper.KEYRING_VERSION),
354 data.strip(),
355 )
356 384
357 385
358class VerifyRev(RepoWrapperTestCase): 386class TestVerifyRev:
359 """Check verify_rev behavior.""" 387 """Check verify_rev behavior."""
360 388
361 def test_verify_passes(self): 389 def test_verify_passes(self, repo_wrapper: wrapper.Wrapper) -> None:
362 """Check when we have a valid signed tag.""" 390 """Check when we have a valid signed tag."""
363 desc_result = subprocess.CompletedProcess([], 0, "v1.0\n", "") 391 desc_result = subprocess.CompletedProcess([], 0, "v1.0\n", "")
364 gpg_result = subprocess.CompletedProcess([], 0, "", "") 392 gpg_result = subprocess.CompletedProcess([], 0, "", "")
365 with mock.patch.object( 393 with mock.patch.object(
366 self.wrapper, "run_git", side_effect=(desc_result, gpg_result) 394 repo_wrapper, "run_git", side_effect=(desc_result, gpg_result)
367 ): 395 ):
368 ret = self.wrapper.verify_rev( 396 ret = repo_wrapper.verify_rev(
369 "/", "refs/heads/stable", "1234", True 397 "/", "refs/heads/stable", "1234", True
370 ) 398 )
371 self.assertEqual("v1.0^0", ret) 399 assert ret == "v1.0^0"
372 400
373 def test_unsigned_commit(self): 401 def test_unsigned_commit(self, repo_wrapper: wrapper.Wrapper) -> None:
374 """Check we fall back to signed tag when we have an unsigned commit.""" 402 """Check we fall back to signed tag when we have an unsigned commit."""
375 desc_result = subprocess.CompletedProcess([], 0, "v1.0-10-g1234\n", "") 403 desc_result = subprocess.CompletedProcess([], 0, "v1.0-10-g1234\n", "")
376 gpg_result = subprocess.CompletedProcess([], 0, "", "") 404 gpg_result = subprocess.CompletedProcess([], 0, "", "")
377 with mock.patch.object( 405 with mock.patch.object(
378 self.wrapper, "run_git", side_effect=(desc_result, gpg_result) 406 repo_wrapper, "run_git", side_effect=(desc_result, gpg_result)
379 ): 407 ):
380 ret = self.wrapper.verify_rev( 408 ret = repo_wrapper.verify_rev(
381 "/", "refs/heads/stable", "1234", True 409 "/", "refs/heads/stable", "1234", True
382 ) 410 )
383 self.assertEqual("v1.0^0", ret) 411 assert ret == "v1.0^0"
384 412
385 def test_verify_fails(self): 413 def test_verify_fails(self, repo_wrapper: wrapper.Wrapper) -> None:
386 """Check we fall back to signed tag when we have an unsigned commit.""" 414 """Check we fall back to signed tag when we have an unsigned commit."""
387 desc_result = subprocess.CompletedProcess([], 0, "v1.0-10-g1234\n", "") 415 desc_result = subprocess.CompletedProcess([], 0, "v1.0-10-g1234\n", "")
388 gpg_result = Exception 416 gpg_result = RuntimeError
389 with mock.patch.object( 417 with mock.patch.object(
390 self.wrapper, "run_git", side_effect=(desc_result, gpg_result) 418 repo_wrapper, "run_git", side_effect=(desc_result, gpg_result)
391 ): 419 ):
392 with self.assertRaises(Exception): 420 with pytest.raises(RuntimeError):
393 self.wrapper.verify_rev("/", "refs/heads/stable", "1234", True) 421 repo_wrapper.verify_rev("/", "refs/heads/stable", "1234", True)
394
395
396class GitCheckoutTestCase(RepoWrapperTestCase):
397 """Tests that use a real/small git checkout."""
398
399 GIT_DIR = None
400 REV_LIST = None
401
402 @classmethod
403 def setUpClass(cls):
404 # Create a repo to operate on, but do it once per-class.
405 cls.tempdirobj = tempfile.TemporaryDirectory(prefix="repo-rev-tests")
406 cls.GIT_DIR = cls.tempdirobj.name
407 run_git = wrapper.Wrapper().run_git
408
409 remote = os.path.join(cls.GIT_DIR, "remote")
410 os.mkdir(remote)
411
412 utils_for_test.init_git_tree(remote)
413 run_git("commit", "--allow-empty", "-minit", cwd=remote)
414 run_git("branch", "stable", cwd=remote)
415 run_git("tag", "v1.0", cwd=remote)
416 run_git("commit", "--allow-empty", "-m2nd commit", cwd=remote)
417 cls.REV_LIST = run_git(
418 "rev-list", "HEAD", cwd=remote
419 ).stdout.splitlines()
420
421 run_git("init", cwd=cls.GIT_DIR)
422 run_git(
423 "fetch",
424 remote,
425 "+refs/heads/*:refs/remotes/origin/*",
426 cwd=cls.GIT_DIR,
427 )
428
429 @classmethod
430 def tearDownClass(cls):
431 if not cls.tempdirobj:
432 return
433
434 cls.tempdirobj.cleanup()
435 422
436 423
437class ResolveRepoRev(GitCheckoutTestCase): 424class TestResolveRepoRev:
438 """Check resolve_repo_rev behavior.""" 425 """Check resolve_repo_rev behavior."""
439 426
440 def test_explicit_branch(self): 427 def test_explicit_branch(
428 self,
429 repo_wrapper: wrapper.Wrapper,
430 git_checkout: GitCheckout,
431 ) -> None:
441 """Check refs/heads/branch argument.""" 432 """Check refs/heads/branch argument."""
442 rrev, lrev = self.wrapper.resolve_repo_rev( 433 rrev, lrev = repo_wrapper.resolve_repo_rev(
443 self.GIT_DIR, "refs/heads/stable" 434 git_checkout.git_dir, "refs/heads/stable"
444 ) 435 )
445 self.assertEqual("refs/heads/stable", rrev) 436 assert rrev == "refs/heads/stable"
446 self.assertEqual(self.REV_LIST[1], lrev) 437 assert lrev == git_checkout.rev_list[1]
447 438
448 with self.assertRaises(self.wrapper.CloneFailure): 439 with pytest.raises(repo_wrapper.CloneFailure):
449 self.wrapper.resolve_repo_rev(self.GIT_DIR, "refs/heads/unknown") 440 repo_wrapper.resolve_repo_rev(
441 git_checkout.git_dir, "refs/heads/unknown"
442 )
450 443
451 def test_explicit_tag(self): 444 def test_explicit_tag(
445 self,
446 repo_wrapper: wrapper.Wrapper,
447 git_checkout: GitCheckout,
448 ) -> None:
452 """Check refs/tags/tag argument.""" 449 """Check refs/tags/tag argument."""
453 rrev, lrev = self.wrapper.resolve_repo_rev( 450 rrev, lrev = repo_wrapper.resolve_repo_rev(
454 self.GIT_DIR, "refs/tags/v1.0" 451 git_checkout.git_dir, "refs/tags/v1.0"
455 ) 452 )
456 self.assertEqual("refs/tags/v1.0", rrev) 453 assert rrev == "refs/tags/v1.0"
457 self.assertEqual(self.REV_LIST[1], lrev) 454 assert lrev == git_checkout.rev_list[1]
458 455
459 with self.assertRaises(self.wrapper.CloneFailure): 456 with pytest.raises(repo_wrapper.CloneFailure):
460 self.wrapper.resolve_repo_rev(self.GIT_DIR, "refs/tags/unknown") 457 repo_wrapper.resolve_repo_rev(
458 git_checkout.git_dir, "refs/tags/unknown"
459 )
461 460
462 def test_branch_name(self): 461 def test_branch_name(
462 self,
463 repo_wrapper: wrapper.Wrapper,
464 git_checkout: GitCheckout,
465 ) -> None:
463 """Check branch argument.""" 466 """Check branch argument."""
464 rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "stable") 467 rrev, lrev = repo_wrapper.resolve_repo_rev(
465 self.assertEqual("refs/heads/stable", rrev) 468 git_checkout.git_dir, "stable"
466 self.assertEqual(self.REV_LIST[1], lrev) 469 )
467 470 assert rrev == "refs/heads/stable"
468 rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "main") 471 assert lrev == git_checkout.rev_list[1]
469 self.assertEqual("refs/heads/main", rrev) 472
470 self.assertEqual(self.REV_LIST[0], lrev) 473 rrev, lrev = repo_wrapper.resolve_repo_rev(git_checkout.git_dir, "main")
471 474 assert rrev == "refs/heads/main"
472 def test_tag_name(self): 475 assert lrev == git_checkout.rev_list[0]
476
477 def test_tag_name(
478 self,
479 repo_wrapper: wrapper.Wrapper,
480 git_checkout: GitCheckout,
481 ) -> None:
473 """Check tag argument.""" 482 """Check tag argument."""
474 rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "v1.0") 483 rrev, lrev = repo_wrapper.resolve_repo_rev(git_checkout.git_dir, "v1.0")
475 self.assertEqual("refs/tags/v1.0", rrev) 484 assert rrev == "refs/tags/v1.0"
476 self.assertEqual(self.REV_LIST[1], lrev) 485 assert lrev == git_checkout.rev_list[1]
477 486
478 def test_full_commit(self): 487 def test_full_commit(
488 self,
489 repo_wrapper: wrapper.Wrapper,
490 git_checkout: GitCheckout,
491 ) -> None:
479 """Check specific commit argument.""" 492 """Check specific commit argument."""
480 commit = self.REV_LIST[0] 493 commit = git_checkout.rev_list[0]
481 rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, commit) 494 rrev, lrev = repo_wrapper.resolve_repo_rev(git_checkout.git_dir, commit)
482 self.assertEqual(commit, rrev) 495 assert rrev == commit
483 self.assertEqual(commit, lrev) 496 assert lrev == commit
484 497
485 def test_partial_commit(self): 498 def test_partial_commit(
499 self,
500 repo_wrapper: wrapper.Wrapper,
501 git_checkout: GitCheckout,
502 ) -> None:
486 """Check specific (partial) commit argument.""" 503 """Check specific (partial) commit argument."""
487 commit = self.REV_LIST[0][0:20] 504 commit = git_checkout.rev_list[0][0:20]
488 rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, commit) 505 rrev, lrev = repo_wrapper.resolve_repo_rev(git_checkout.git_dir, commit)
489 self.assertEqual(self.REV_LIST[0], rrev) 506 assert rrev == git_checkout.rev_list[0]
490 self.assertEqual(self.REV_LIST[0], lrev) 507 assert lrev == git_checkout.rev_list[0]
491 508
492 def test_unknown(self): 509 def test_unknown(
510 self,
511 repo_wrapper: wrapper.Wrapper,
512 git_checkout: GitCheckout,
513 ) -> None:
493 """Check unknown ref/commit argument.""" 514 """Check unknown ref/commit argument."""
494 with self.assertRaises(self.wrapper.CloneFailure): 515 with pytest.raises(repo_wrapper.CloneFailure):
495 self.wrapper.resolve_repo_rev(self.GIT_DIR, "boooooooya") 516 repo_wrapper.resolve_repo_rev(git_checkout.git_dir, "boooooooya")
496 517
497 518
498class CheckRepoVerify(RepoWrapperTestCase): 519class TestCheckRepoVerify:
499 """Check check_repo_verify behavior.""" 520 """Check check_repo_verify behavior."""
500 521
501 def test_no_verify(self): 522 def test_no_verify(self, repo_wrapper: wrapper.Wrapper) -> None:
502 """Always fail with --no-repo-verify.""" 523 """Always fail with --no-repo-verify."""
503 self.assertFalse(self.wrapper.check_repo_verify(False)) 524 assert not repo_wrapper.check_repo_verify(False)
504 525
505 def test_gpg_initialized(self): 526 def test_gpg_initialized(
527 self,
528 repo_wrapper: wrapper.Wrapper,
529 ) -> None:
506 """Should pass if gpg is setup already.""" 530 """Should pass if gpg is setup already."""
507 with mock.patch.object( 531 with mock.patch.object(
508 self.wrapper, "NeedSetupGnuPG", return_value=False 532 repo_wrapper, "NeedSetupGnuPG", return_value=False
509 ): 533 ):
510 self.assertTrue(self.wrapper.check_repo_verify(True)) 534 assert repo_wrapper.check_repo_verify(True)
511 535
512 def test_need_gpg_setup(self): 536 def test_need_gpg_setup(
537 self,
538 repo_wrapper: wrapper.Wrapper,
539 ) -> None:
513 """Should pass/fail based on gpg setup.""" 540 """Should pass/fail based on gpg setup."""
514 with mock.patch.object( 541 with mock.patch.object(
515 self.wrapper, "NeedSetupGnuPG", return_value=True 542 repo_wrapper, "NeedSetupGnuPG", return_value=True
516 ): 543 ):
517 with mock.patch.object(self.wrapper, "SetupGnuPG") as m: 544 with mock.patch.object(repo_wrapper, "SetupGnuPG") as m:
518 m.return_value = True 545 m.return_value = True
519 self.assertTrue(self.wrapper.check_repo_verify(True)) 546 assert repo_wrapper.check_repo_verify(True)
520 547
521 m.return_value = False 548 m.return_value = False
522 self.assertFalse(self.wrapper.check_repo_verify(True)) 549 assert not repo_wrapper.check_repo_verify(True)
523 550
524 551
525class CheckRepoRev(GitCheckoutTestCase): 552class TestCheckRepoRev:
526 """Check check_repo_rev behavior.""" 553 """Check check_repo_rev behavior."""
527 554
528 def test_verify_works(self): 555 def test_verify_works(
556 self,
557 repo_wrapper: wrapper.Wrapper,
558 git_checkout: GitCheckout,
559 ) -> None:
529 """Should pass when verification passes.""" 560 """Should pass when verification passes."""
530 with mock.patch.object( 561 with mock.patch.object(
531 self.wrapper, "check_repo_verify", return_value=True 562 repo_wrapper, "check_repo_verify", return_value=True
532 ): 563 ):
533 with mock.patch.object( 564 with mock.patch.object(
534 self.wrapper, "verify_rev", return_value="12345" 565 repo_wrapper, "verify_rev", return_value="12345"
535 ): 566 ):
536 rrev, lrev = self.wrapper.check_repo_rev(self.GIT_DIR, "stable") 567 rrev, lrev = repo_wrapper.check_repo_rev(
537 self.assertEqual("refs/heads/stable", rrev) 568 git_checkout.git_dir, "stable"
538 self.assertEqual("12345", lrev) 569 )
539 570 assert rrev == "refs/heads/stable"
540 def test_verify_fails(self): 571 assert lrev == "12345"
572
573 def test_verify_fails(
574 self,
575 repo_wrapper: wrapper.Wrapper,
576 git_checkout: GitCheckout,
577 ) -> None:
541 """Should fail when verification fails.""" 578 """Should fail when verification fails."""
542 with mock.patch.object( 579 with mock.patch.object(
543 self.wrapper, "check_repo_verify", return_value=True 580 repo_wrapper, "check_repo_verify", return_value=True
544 ): 581 ):
545 with mock.patch.object( 582 with mock.patch.object(
546 self.wrapper, "verify_rev", side_effect=Exception 583 repo_wrapper, "verify_rev", side_effect=RuntimeError
547 ): 584 ):
548 with self.assertRaises(Exception): 585 with pytest.raises(RuntimeError):
549 self.wrapper.check_repo_rev(self.GIT_DIR, "stable") 586 repo_wrapper.check_repo_rev(git_checkout.git_dir, "stable")
550 587
551 def test_verify_ignore(self): 588 def test_verify_ignore(
589 self,
590 repo_wrapper: wrapper.Wrapper,
591 git_checkout: GitCheckout,
592 ) -> None:
552 """Should pass when verification is disabled.""" 593 """Should pass when verification is disabled."""
553 with mock.patch.object( 594 with mock.patch.object(
554 self.wrapper, "verify_rev", side_effect=Exception 595 repo_wrapper, "verify_rev", side_effect=RuntimeError
555 ): 596 ):
556 rrev, lrev = self.wrapper.check_repo_rev( 597 rrev, lrev = repo_wrapper.check_repo_rev(
557 self.GIT_DIR, "stable", repo_verify=False 598 git_checkout.git_dir, "stable", repo_verify=False
558 ) 599 )
559 self.assertEqual("refs/heads/stable", rrev) 600 assert rrev == "refs/heads/stable"
560 self.assertEqual(self.REV_LIST[1], lrev) 601 assert lrev == git_checkout.rev_list[1]
diff --git a/tests/utils_for_test.py b/tests/utils_for_test.py
index f48613f34..3f1bed486 100644
--- a/tests/utils_for_test.py
+++ b/tests/utils_for_test.py
@@ -27,6 +27,11 @@ from typing import Optional, Union
27import git_command 27import git_command
28 28
29 29
30THIS_FILE = Path(__file__).resolve()
31THIS_DIR = THIS_FILE.parent
32FIXTURES_DIR = THIS_DIR / "fixtures"
33
34
30def init_git_tree( 35def init_git_tree(
31 path: Union[str, Path], 36 path: Union[str, Path],
32 ref_format: Optional[str] = None, 37 ref_format: Optional[str] = None,