 git_command.py    |  12
 git_config.py     |  11
 project.py        |   4
 ssh.py            | 282
 subcmds/sync.py   |  44
 tests/test_ssh.py |  30
 6 files changed, 225 insertions(+), 158 deletions(-)
diff --git a/git_command.py b/git_command.py
index fabad0e0..04953f38 100644
--- a/git_command.py
+++ b/git_command.py
@@ -21,7 +21,6 @@ from error import GitError
 from git_refs import HEAD
 import platform_utils
 from repo_trace import REPO_TRACE, IsTrace, Trace
-import ssh
 from wrapper import Wrapper
 
 GIT = 'git'
@@ -167,7 +166,7 @@ class GitCommand(object):
                capture_stderr=False,
                merge_output=False,
                disable_editor=False,
-               ssh_proxy=False,
+               ssh_proxy=None,
                cwd=None,
                gitdir=None):
     env = self._GetBasicEnv()
@@ -175,8 +174,8 @@ class GitCommand(object):
     if disable_editor:
       env['GIT_EDITOR'] = ':'
     if ssh_proxy:
-      env['REPO_SSH_SOCK'] = ssh.sock()
-      env['GIT_SSH'] = ssh.proxy()
+      env['REPO_SSH_SOCK'] = ssh_proxy.sock()
+      env['GIT_SSH'] = ssh_proxy.proxy
       env['GIT_SSH_VARIANT'] = 'ssh'
     if 'http_proxy' in env and 'darwin' == sys.platform:
       s = "'http.proxy=%s'" % (env['http_proxy'],)
@@ -259,7 +258,7 @@ class GitCommand(object):
       raise GitError('%s: %s' % (command[1], e))
 
     if ssh_proxy:
-      ssh.add_client(p)
+      ssh_proxy.add_client(p)
 
     self.process = p
     if input:
@@ -271,7 +270,8 @@ class GitCommand(object):
     try:
       self.stdout, self.stderr = p.communicate()
     finally:
-      ssh.remove_client(p)
+      if ssh_proxy:
+        ssh_proxy.remove_client(p)
     self.rc = p.wait()
 
   @staticmethod
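Those two environment variables are the whole hand-off between repo and git. A minimal sketch of the consumer side, assuming the bundled `git_ssh` wrapper reads `$REPO_SSH_SOCK` to locate the shared ControlPath socket; the `fetch_with_proxy` helper is hypothetical, not part of the patch:

```python
import os
import subprocess

def fetch_with_proxy(ssh_proxy, remote='origin'):
  """Run 'git fetch' with ssh connections multiplexed via |ssh_proxy|."""
  env = os.environ.copy()
  env['REPO_SSH_SOCK'] = ssh_proxy.sock()  # ControlPath template for master sockets
  env['GIT_SSH'] = ssh_proxy.proxy         # path to the git_ssh wrapper script
  env['GIT_SSH_VARIANT'] = 'ssh'           # git should assume OpenSSH-style options
  return subprocess.run(['git', 'fetch', remote], env=env, check=True)
```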
diff --git a/git_config.py b/git_config.py
index d7fef8ca..978f6a59 100644
--- a/git_config.py
+++ b/git_config.py
@@ -27,7 +27,6 @@ import urllib.request
 from error import GitError, UploadError
 import platform_utils
 from repo_trace import Trace
-import ssh
 from git_command import GitCommand
 from git_refs import R_CHANGES, R_HEADS, R_TAGS
 
@@ -519,17 +518,23 @@ class Remote(object):
 
     return self.url.replace(longest, longestUrl, 1)
 
-  def PreConnectFetch(self):
+  def PreConnectFetch(self, ssh_proxy):
     """Run any setup for this remote before we connect to it.
 
     In practice, if the remote is using SSH, we'll attempt to create a new
     SSH master session to it for reuse across projects.
 
+    Args:
+      ssh_proxy: The SSH settings for managing master sessions.
+
     Returns:
       Whether the preconnect phase for this remote was successful.
     """
+    if not ssh_proxy:
+      return True
+
     connectionUrl = self._InsteadOf()
-    return ssh.preconnect(connectionUrl)
+    return ssh_proxy.preconnect(connectionUrl)
 
   def ReviewUrl(self, userEmail, validate_certs):
     if self._review_url is None:
diff --git a/project.py b/project.py
--- a/project.py
+++ b/project.py
@@ -2045,8 +2045,8 @@ class Project(object):
     name = self.remote.name
 
     remote = self.GetRemote(name)
-    if not remote.PreConnectFetch():
-      ssh_proxy = False
+    if not remote.PreConnectFetch(ssh_proxy):
+      ssh_proxy = None
 
     if initial:
       if alt_dir and 'objects' == os.path.basename(alt_dir):
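`PreConnectFetch` only pays off for ssh-style remotes; anything else falls through and the project drops its proxy. Illustrative only (a standalone snippet using the regexes that ssh.py introduces below; the example URLs are made up): how remote URLs are classified during preconnect.

```python
import re

# Copied from ssh.py below: scp-style "user@host:path" vs. full "scheme://" URLs.
URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')

for url in ('ssh://git@review.example.com:29418/platform/build/',  # ssh scheme
            'git+ssh://example.com/repo/',                         # ssh-ish scheme
            'https://example.com/repo/',                           # non-ssh scheme
            'git@example.com:repo.git'):                           # scp-style syntax
  m = URI_ALL.match(url)
  scheme = m.group(1) if m else None
  # A master is only opened for ssh-ish schemes or scp-style syntax.
  wants_master = (scheme in ('ssh', 'git+ssh', 'ssh+git')
                  or (not m and bool(URI_SCP.match(url))))
  print('%-50s -> %s' % (url, wants_master))
```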
diff --git a/ssh.py b/ssh.py
--- a/ssh.py
+++ b/ssh.py
@@ -15,25 +15,20 @@
 """Common SSH management logic."""
 
 import functools
+import multiprocessing
 import os
 import re
 import signal
 import subprocess
 import sys
 import tempfile
-try:
-  import threading as _threading
-except ImportError:
-  import dummy_threading as _threading
 import time
 
 import platform_utils
 from repo_trace import Trace
 
 
-_ssh_proxy_path = None
-_ssh_sock_path = None
-_ssh_clients = []
+PROXY_PATH = os.path.join(os.path.dirname(__file__), 'git_ssh')
 
 
 def _run_ssh_version():
@@ -62,68 +57,104 @@ def version():
   sys.exit(1)
 
 
-def proxy():
-  global _ssh_proxy_path
-  if _ssh_proxy_path is None:
-    _ssh_proxy_path = os.path.join(
-        os.path.dirname(__file__),
-        'git_ssh')
-  return _ssh_proxy_path
-
-
-def add_client(p):
-  _ssh_clients.append(p)
-
-
-def remove_client(p):
-  try:
-    _ssh_clients.remove(p)
-  except ValueError:
-    pass
-
-
-def _terminate_clients():
-  global _ssh_clients
-  for p in _ssh_clients:
-    try:
-      os.kill(p.pid, signal.SIGTERM)
-      p.wait()
-    except OSError:
-      pass
-  _ssh_clients = []
-
-
-_master_processes = []
-_master_keys = set()
-_ssh_master = True
-_master_keys_lock = None
-
-
-def init():
-  """Should be called once at the start of repo to init ssh master handling.
-
-  At the moment, all we do is to create our lock.
-  """
-  global _master_keys_lock
-  assert _master_keys_lock is None, "Should only call init once"
-  _master_keys_lock = _threading.Lock()
-
-
-def _open_ssh(host, port=None):
-  global _ssh_master
-
-  # Bail before grabbing the lock if we already know that we aren't going to
-  # try creating new masters below.
-  if sys.platform in ('win32', 'cygwin'):
-    return False
-
-  # Acquire the lock. This is needed to prevent opening multiple masters for
-  # the same host when we're running "repo sync -jN" (for N > 1) _and_ the
-  # manifest <remote fetch="ssh://xyz"> specifies a different host from the
-  # one that was passed to repo init.
-  _master_keys_lock.acquire()
-  try:
+URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
+URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')
+
+
+class ProxyManager:
+  """Manage various ssh clients & masters that we spawn.
+
+  This will take care of sharing state between multiprocessing children, and
+  make sure that if we crash, we don't leak any of the ssh sessions.
+
+  The code should work with a single-process scenario too, and not add too much
+  overhead due to the manager.
+  """
+
+  # Path to the ssh program to run which will pass our master settings along.
+  # Set here more as a convenience API.
+  proxy = PROXY_PATH
+
+  def __init__(self, manager):
+    # Protect access to the list of active masters.
+    self._lock = multiprocessing.Lock()
+    # List of active masters (pid). These will be spawned on demand, and we are
+    # responsible for shutting them all down at the end.
+    self._masters = manager.list()
+    # Set of active masters indexed by "host:port" information.
+    # The value isn't used, but multiprocessing doesn't provide a set class.
+    self._master_keys = manager.dict()
+    # Whether ssh masters are known to be broken, so we give up entirely.
+    self._master_broken = manager.Value('b', False)
+    # List of active ssh sessions. Clients will be added & removed as
+    # connections finish, so this list is just for safety & cleanup if we crash.
+    self._clients = manager.list()
+    # Path to directory for holding master sockets.
+    self._sock_path = None
+
+  def __enter__(self):
+    """Enter a new context."""
+    return self
+
+  def __exit__(self, exc_type, exc_value, traceback):
+    """Exit a context & clean up all resources."""
+    self.close()
+
+  def add_client(self, proc):
+    """Track a new ssh session."""
+    self._clients.append(proc.pid)
+
+  def remove_client(self, proc):
+    """Remove a completed ssh session."""
+    try:
+      self._clients.remove(proc.pid)
+    except ValueError:
+      pass
+
+  def add_master(self, proc):
+    """Track a new master connection."""
+    self._masters.append(proc.pid)
+
+  def _terminate(self, procs):
+    """Kill all |procs|."""
+    for pid in procs:
+      try:
+        os.kill(pid, signal.SIGTERM)
+        os.waitpid(pid, 0)
+      except OSError:
+        pass
+
+    # The multiprocessing.list() API doesn't provide many standard list()
+    # methods, so we have to manually clear the list.
+    while True:
+      try:
+        procs.pop(0)
+      except:
+        break
+
+  def close(self):
+    """Close this active ssh session.
+
+    Kill all ssh clients & masters we created, and nuke the socket dir.
+    """
+    self._terminate(self._clients)
+    self._terminate(self._masters)
+
+    d = self.sock(create=False)
+    if d:
+      try:
+        platform_utils.rmdir(os.path.dirname(d))
+      except OSError:
+        pass
+
+  def _open_unlocked(self, host, port=None):
+    """Make sure a ssh master session exists for |host| & |port|.
+
+    If one doesn't exist already, we'll create it.
+
+    We won't grab any locks, so the caller has to do that. This helps keep the
+    business logic of actually creating the master separate from grabbing locks.
+    """
     # Check to see whether we already think that the master is running; if we
     # think it's already running, return right away.
     if port is not None:
@@ -131,17 +162,15 @@ def _open_ssh(host, port=None):
     else:
       key = host
 
-    if key in _master_keys:
+    if key in self._master_keys:
       return True
 
-    if not _ssh_master or 'GIT_SSH' in os.environ:
+    if self._master_broken.value or 'GIT_SSH' in os.environ:
       # Failed earlier, so don't retry.
       return False
 
     # We will make two calls to ssh; this is the common part of both calls.
-    command_base = ['ssh',
-                    '-o', 'ControlPath %s' % sock(),
-                    host]
+    command_base = ['ssh', '-o', 'ControlPath %s' % self.sock(), host]
     if port is not None:
       command_base[1:1] = ['-p', str(port)]
 
@@ -161,7 +190,7 @@ def _open_ssh(host, port=None):
       if not isnt_running:
         # Our double-check found that the master _was_ in fact running. Add to
         # the list of keys.
-        _master_keys.add(key)
+        self._master_keys[key] = True
         return True
     except Exception:
       # Ignore exceptions. We will fall back to the normal command and print
@@ -173,7 +202,7 @@ def _open_ssh(host, port=None):
       Trace(': %s', ' '.join(command))
       p = subprocess.Popen(command)
     except Exception as e:
-      _ssh_master = False
+      self._master_broken.value = True
       print('\nwarn: cannot enable ssh control master for %s:%s\n%s'
             % (host, port, str(e)), file=sys.stderr)
       return False
@@ -183,75 +212,66 @@ def _open_ssh(host, port=None):
     if ssh_died:
       return False
 
-    _master_processes.append(p)
-    _master_keys.add(key)
+    self.add_master(p)
+    self._master_keys[key] = True
     return True
-  finally:
-    _master_keys_lock.release()
-
-
-def close():
-  global _master_keys_lock
-
-  _terminate_clients()
-
-  for p in _master_processes:
-    try:
-      os.kill(p.pid, signal.SIGTERM)
-      p.wait()
-    except OSError:
-      pass
-  del _master_processes[:]
-  _master_keys.clear()
-
-  d = sock(create=False)
-  if d:
-    try:
-      platform_utils.rmdir(os.path.dirname(d))
-    except OSError:
-      pass
-
-  # We're done with the lock, so we can delete it.
-  _master_keys_lock = None
-
-
-URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
-URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')
-
-
-def preconnect(url):
-  m = URI_ALL.match(url)
-  if m:
-    scheme = m.group(1)
-    host = m.group(2)
-    if ':' in host:
-      host, port = host.split(':')
-    else:
-      port = None
-    if scheme in ('ssh', 'git+ssh', 'ssh+git'):
-      return _open_ssh(host, port)
-    return False
-
-  m = URI_SCP.match(url)
-  if m:
-    host = m.group(1)
-    return _open_ssh(host)
-
-  return False
-
-def sock(create=True):
-  global _ssh_sock_path
-  if _ssh_sock_path is None:
-    if not create:
-      return None
-    tmp_dir = '/tmp'
-    if not os.path.exists(tmp_dir):
-      tmp_dir = tempfile.gettempdir()
-    if version() < (6, 7):
-      tokens = '%r@%h:%p'
-    else:
-      tokens = '%C'  # hash of %l%h%p%r
-    _ssh_sock_path = os.path.join(
-        tempfile.mkdtemp('', 'ssh-', tmp_dir),
-        'master-' + tokens)
-  return _ssh_sock_path
+
+  def _open(self, host, port=None):
+    """Make sure a ssh master session exists for |host| & |port|.
+
+    If one doesn't exist already, we'll create it.
+
+    This will obtain any necessary locks to avoid inter-process races.
+    """
+    # Bail before grabbing the lock if we already know that we aren't going to
+    # try creating new masters below.
+    if sys.platform in ('win32', 'cygwin'):
+      return False
+
+    # Acquire the lock. This is needed to prevent opening multiple masters for
+    # the same host when we're running "repo sync -jN" (for N > 1) _and_ the
+    # manifest <remote fetch="ssh://xyz"> specifies a different host from the
+    # one that was passed to repo init.
+    with self._lock:
+      return self._open_unlocked(host, port)
+
+  def preconnect(self, url):
+    """If |url| will create a ssh connection, setup the ssh master for it."""
+    m = URI_ALL.match(url)
+    if m:
+      scheme = m.group(1)
+      host = m.group(2)
+      if ':' in host:
+        host, port = host.split(':')
+      else:
+        port = None
+      if scheme in ('ssh', 'git+ssh', 'ssh+git'):
+        return self._open(host, port)
+      return False
+
+    m = URI_SCP.match(url)
+    if m:
+      host = m.group(1)
+      return self._open(host)
+
+    return False
+
+  def sock(self, create=True):
+    """Return the path to the ssh socket dir.
+
+    This has all the master sockets so clients can talk to them.
+    """
+    if self._sock_path is None:
+      if not create:
+        return None
+      tmp_dir = '/tmp'
+      if not os.path.exists(tmp_dir):
+        tmp_dir = tempfile.gettempdir()
+      if version() < (6, 7):
+        tokens = '%r@%h:%p'
+      else:
+        tokens = '%C'  # hash of %l%h%p%r
+      self._sock_path = os.path.join(
+          tempfile.mkdtemp('', 'ssh-', tmp_dir),
+          'master-' + tokens)
+    return self._sock_path
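Pulling the pieces together, a minimal usage sketch of the new class, mirroring how subcmds/sync.py below drives it (the remote URL is made up):

```python
import multiprocessing

import ssh

def main():
  with multiprocessing.Manager() as manager:
    with ssh.ProxyManager(manager) as ssh_proxy:
      # Create the socket dir once in the parent so every child reuses it.
      ssh_proxy.sock()
      # Open a master for an ssh remote; returns False for non-ssh URLs.
      ssh_proxy.preconnect('ssh://git@example.com/platform/manifest')
      # ... hand |ssh_proxy| to GitCommand / fetch workers here ...
  # Leaving the contexts kills any leftover clients & masters and
  # removes the temporary socket directory.

if __name__ == '__main__':
  main()
```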
diff --git a/subcmds/sync.py b/subcmds/sync.py
index 28568062..fb25c221 100644
--- a/subcmds/sync.py
+++ b/subcmds/sync.py
@@ -358,7 +358,7 @@ later is required to fix a server side protocol bug.
           optimized_fetch=opt.optimized_fetch,
           retry_fetches=opt.retry_fetches,
           prune=opt.prune,
-          ssh_proxy=True,
+          ssh_proxy=self.ssh_proxy,
           clone_filter=self.manifest.CloneFilter,
           partial_clone_exclude=self.manifest.PartialCloneExclude)
 
@@ -380,7 +380,11 @@ later is required to fix a server side protocol bug.
     finish = time.time()
     return (success, project, start, finish)
 
-  def _Fetch(self, projects, opt, err_event):
+  @classmethod
+  def _FetchInitChild(cls, ssh_proxy):
+    cls.ssh_proxy = ssh_proxy
+
+  def _Fetch(self, projects, opt, err_event, ssh_proxy):
     ret = True
 
     jobs = opt.jobs_network if opt.jobs_network else self.jobs
@@ -410,8 +414,14 @@ later is required to fix a server side protocol bug.
         break
     return ret
 
+    # We pass the ssh proxy settings via the class. This allows multiprocessing
+    # to pickle it up when spawning children. We can't pass it as an argument
+    # to _FetchProjectList below as multiprocessing is unable to pickle those.
+    Sync.ssh_proxy = None
+
     # NB: Multiprocessing is heavy, so don't spin it up for one job.
     if len(projects_list) == 1 or jobs == 1:
+      self._FetchInitChild(ssh_proxy)
       if not _ProcessResults(self._FetchProjectList(opt, x) for x in projects_list):
         ret = False
     else:
@@ -429,7 +439,8 @@ later is required to fix a server side protocol bug.
       else:
         pm.update(inc=0, msg='warming up')
       chunksize = 4
-      with multiprocessing.Pool(jobs) as pool:
+      with multiprocessing.Pool(
+          jobs, initializer=self._FetchInitChild, initargs=(ssh_proxy,)) as pool:
         results = pool.imap_unordered(
             functools.partial(self._FetchProjectList, opt),
             projects_list,
@@ -438,6 +449,11 @@ later is required to fix a server side protocol bug.
           ret = False
       pool.close()
 
+    # Cleanup the reference now that we're done with it, and we're going to
+    # release any resources it points to. If we don't, later multiprocessing
+    # usage (e.g. checkouts) will try to pickle and then crash.
+    del Sync.ssh_proxy
+
     pm.end()
     self._fetch_times.Save()
 
@@ -447,7 +463,7 @@ later is required to fix a server side protocol bug.
     return (ret, fetched)
 
   def _FetchMain(self, opt, args, all_projects, err_event, manifest_name,
-                 load_local_manifests):
+                 load_local_manifests, ssh_proxy):
     """The main network fetch loop.
 
     Args:
@@ -457,6 +473,7 @@ later is required to fix a server side protocol bug.
       err_event: Whether an error was hit while processing.
       manifest_name: Manifest file to be reloaded.
       load_local_manifests: Whether to load local manifests.
+      ssh_proxy: SSH manager for clients & masters.
     """
     rp = self.manifest.repoProject
 
@@ -467,7 +484,7 @@ later is required to fix a server side protocol bug.
     to_fetch.extend(all_projects)
     to_fetch.sort(key=self._fetch_times.Get, reverse=True)
 
-    success, fetched = self._Fetch(to_fetch, opt, err_event)
+    success, fetched = self._Fetch(to_fetch, opt, err_event, ssh_proxy)
     if not success:
       err_event.set()
 
@@ -498,7 +515,7 @@ later is required to fix a server side protocol bug.
       if previously_missing_set == missing_set:
         break
       previously_missing_set = missing_set
-      success, new_fetched = self._Fetch(missing, opt, err_event)
+      success, new_fetched = self._Fetch(missing, opt, err_event, ssh_proxy)
       if not success:
         err_event.set()
       fetched.update(new_fetched)
@@ -985,12 +1002,15 @@ later is required to fix a server side protocol bug.
 
     self._fetch_times = _FetchTimes(self.manifest)
     if not opt.local_only:
-      try:
-        ssh.init()
-        self._FetchMain(opt, args, all_projects, err_event, manifest_name,
-                        load_local_manifests)
-      finally:
-        ssh.close()
+      with multiprocessing.Manager() as manager:
+        with ssh.ProxyManager(manager) as ssh_proxy:
+          # Initialize the socket dir once in the parent.
+          ssh_proxy.sock()
+          self._FetchMain(opt, args, all_projects, err_event, manifest_name,
+                          load_local_manifests, ssh_proxy)
+
+      if opt.network_only:
+        return
 
     # If we saw an error, exit with code 1 so that other scripts can check.
     if err_event.is_set():
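The initializer/initargs pair is the standard way to hand each pool worker state that can't be pickled per task. A standalone sketch of the same pattern, with hypothetical names (`_init_worker`, `_fetch_one`) and a manager dict standing in for the proxy:

```python
import multiprocessing

_proxy = None

def _init_worker(proxy):
  """Runs once per worker process; stashes shared state for later tasks."""
  global _proxy
  _proxy = proxy

def _fetch_one(name):
  # Every task in this worker sees the proxy installed by the initializer.
  return '%s fetched (proxy=%r)' % (name, _proxy)

if __name__ == '__main__':
  with multiprocessing.Manager() as manager:
    shared = manager.dict()  # stand-in for ssh.ProxyManager's shared state
    with multiprocessing.Pool(2, initializer=_init_worker,
                              initargs=(shared,)) as pool:
      for line in pool.imap_unordered(_fetch_one, ['a', 'b', 'c']):
        print(line)
```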
diff --git a/tests/test_ssh.py b/tests/test_ssh.py
index 5a4f27e4..ffb5cb94 100644
--- a/tests/test_ssh.py
+++ b/tests/test_ssh.py
@@ -14,6 +14,8 @@
 
 """Unittests for the ssh.py module."""
 
+import multiprocessing
+import subprocess
 import unittest
 from unittest import mock
 
@@ -39,14 +41,34 @@ class SshTests(unittest.TestCase):
     with mock.patch('ssh._run_ssh_version', return_value='OpenSSH_1.2\n'):
       self.assertEqual(ssh.version(), (1, 2))
 
+  def test_context_manager_empty(self):
+    """Verify context manager with no clients works correctly."""
+    with multiprocessing.Manager() as manager:
+      with ssh.ProxyManager(manager):
+        pass
+
+  def test_context_manager_child_cleanup(self):
+    """Verify orphaned clients & masters get cleaned up."""
+    with multiprocessing.Manager() as manager:
+      with ssh.ProxyManager(manager) as ssh_proxy:
+        client = subprocess.Popen(['sleep', '964853320'])
+        ssh_proxy.add_client(client)
+        master = subprocess.Popen(['sleep', '964853321'])
+        ssh_proxy.add_master(master)
+    # If the process still exists, these will throw timeout errors.
+    client.wait(0)
+    master.wait(0)
+
   def test_ssh_sock(self):
     """Check sock() function."""
+    manager = multiprocessing.Manager()
+    proxy = ssh.ProxyManager(manager)
     with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'):
       # old ssh version uses port
       with mock.patch('ssh.version', return_value=(6, 6)):
-        self.assertTrue(ssh.sock().endswith('%p'))
-      ssh._ssh_sock_path = None
+        self.assertTrue(proxy.sock().endswith('%p'))
+
+      proxy._sock_path = None
       # new ssh version uses hash
       with mock.patch('ssh.version', return_value=(6, 7)):
-        self.assertTrue(ssh.sock().endswith('%C'))
-      ssh._ssh_sock_path = None
+        self.assertTrue(proxy.sock().endswith('%C'))
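One footnote on the assertion trick in `test_context_manager_child_cleanup`: `Popen.wait(timeout)` raises `subprocess.TimeoutExpired` while the process is still alive, so `wait(0)` doubles as an "already dead?" check. A tiny demo (assumes a POSIX `sleep` binary):

```python
import subprocess

p = subprocess.Popen(['sleep', '5'])
try:
  p.wait(0)  # still running -> raises TimeoutExpired
except subprocess.TimeoutExpired:
  p.terminate()
  p.wait()   # reaps the terminated process
```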
