diff options
| -rw-r--r-- | git_command.py | 91 | ||||
| -rw-r--r-- | git_config.py | 156 | ||||
| -rwxr-xr-x | main.py | 7 | ||||
| -rw-r--r-- | ssh.py | 257 | ||||
| -rw-r--r-- | tests/test_git_command.py | 32 | ||||
| -rw-r--r-- | tests/test_ssh.py | 52 |
6 files changed, 320 insertions, 275 deletions
diff --git a/git_command.py b/git_command.py index f8cb280c..fabad0e0 100644 --- a/git_command.py +++ b/git_command.py | |||
| @@ -14,16 +14,14 @@ | |||
| 14 | 14 | ||
| 15 | import functools | 15 | import functools |
| 16 | import os | 16 | import os |
| 17 | import re | ||
| 18 | import sys | 17 | import sys |
| 19 | import subprocess | 18 | import subprocess |
| 20 | import tempfile | ||
| 21 | from signal import SIGTERM | ||
| 22 | 19 | ||
| 23 | from error import GitError | 20 | from error import GitError |
| 24 | from git_refs import HEAD | 21 | from git_refs import HEAD |
| 25 | import platform_utils | 22 | import platform_utils |
| 26 | from repo_trace import REPO_TRACE, IsTrace, Trace | 23 | from repo_trace import REPO_TRACE, IsTrace, Trace |
| 24 | import ssh | ||
| 27 | from wrapper import Wrapper | 25 | from wrapper import Wrapper |
| 28 | 26 | ||
| 29 | GIT = 'git' | 27 | GIT = 'git' |
| @@ -43,85 +41,6 @@ GIT_DIR = 'GIT_DIR' | |||
| 43 | LAST_GITDIR = None | 41 | LAST_GITDIR = None |
| 44 | LAST_CWD = None | 42 | LAST_CWD = None |
| 45 | 43 | ||
| 46 | _ssh_proxy_path = None | ||
| 47 | _ssh_sock_path = None | ||
| 48 | _ssh_clients = [] | ||
| 49 | |||
| 50 | |||
| 51 | def _run_ssh_version(): | ||
| 52 | """run ssh -V to display the version number""" | ||
| 53 | return subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT).decode() | ||
| 54 | |||
| 55 | |||
| 56 | def _parse_ssh_version(ver_str=None): | ||
| 57 | """parse a ssh version string into a tuple""" | ||
| 58 | if ver_str is None: | ||
| 59 | ver_str = _run_ssh_version() | ||
| 60 | m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str) | ||
| 61 | if m: | ||
| 62 | return tuple(int(x) for x in m.group(1).split('.')) | ||
| 63 | else: | ||
| 64 | return () | ||
| 65 | |||
| 66 | |||
| 67 | @functools.lru_cache(maxsize=None) | ||
| 68 | def ssh_version(): | ||
| 69 | """return ssh version as a tuple""" | ||
| 70 | try: | ||
| 71 | return _parse_ssh_version() | ||
| 72 | except subprocess.CalledProcessError: | ||
| 73 | print('fatal: unable to detect ssh version', file=sys.stderr) | ||
| 74 | sys.exit(1) | ||
| 75 | |||
| 76 | |||
| 77 | def ssh_sock(create=True): | ||
| 78 | global _ssh_sock_path | ||
| 79 | if _ssh_sock_path is None: | ||
| 80 | if not create: | ||
| 81 | return None | ||
| 82 | tmp_dir = '/tmp' | ||
| 83 | if not os.path.exists(tmp_dir): | ||
| 84 | tmp_dir = tempfile.gettempdir() | ||
| 85 | if ssh_version() < (6, 7): | ||
| 86 | tokens = '%r@%h:%p' | ||
| 87 | else: | ||
| 88 | tokens = '%C' # hash of %l%h%p%r | ||
| 89 | _ssh_sock_path = os.path.join( | ||
| 90 | tempfile.mkdtemp('', 'ssh-', tmp_dir), | ||
| 91 | 'master-' + tokens) | ||
| 92 | return _ssh_sock_path | ||
| 93 | |||
| 94 | |||
| 95 | def _ssh_proxy(): | ||
| 96 | global _ssh_proxy_path | ||
| 97 | if _ssh_proxy_path is None: | ||
| 98 | _ssh_proxy_path = os.path.join( | ||
| 99 | os.path.dirname(__file__), | ||
| 100 | 'git_ssh') | ||
| 101 | return _ssh_proxy_path | ||
| 102 | |||
| 103 | |||
| 104 | def _add_ssh_client(p): | ||
| 105 | _ssh_clients.append(p) | ||
| 106 | |||
| 107 | |||
| 108 | def _remove_ssh_client(p): | ||
| 109 | try: | ||
| 110 | _ssh_clients.remove(p) | ||
| 111 | except ValueError: | ||
| 112 | pass | ||
| 113 | |||
| 114 | |||
| 115 | def terminate_ssh_clients(): | ||
| 116 | global _ssh_clients | ||
| 117 | for p in _ssh_clients: | ||
| 118 | try: | ||
| 119 | os.kill(p.pid, SIGTERM) | ||
| 120 | p.wait() | ||
| 121 | except OSError: | ||
| 122 | pass | ||
| 123 | _ssh_clients = [] | ||
| 124 | |||
| 125 | 44 | ||
| 126 | class _GitCall(object): | 45 | class _GitCall(object): |
| 127 | @functools.lru_cache(maxsize=None) | 46 | @functools.lru_cache(maxsize=None) |
| @@ -256,8 +175,8 @@ class GitCommand(object): | |||
| 256 | if disable_editor: | 175 | if disable_editor: |
| 257 | env['GIT_EDITOR'] = ':' | 176 | env['GIT_EDITOR'] = ':' |
| 258 | if ssh_proxy: | 177 | if ssh_proxy: |
| 259 | env['REPO_SSH_SOCK'] = ssh_sock() | 178 | env['REPO_SSH_SOCK'] = ssh.sock() |
| 260 | env['GIT_SSH'] = _ssh_proxy() | 179 | env['GIT_SSH'] = ssh.proxy() |
| 261 | env['GIT_SSH_VARIANT'] = 'ssh' | 180 | env['GIT_SSH_VARIANT'] = 'ssh' |
| 262 | if 'http_proxy' in env and 'darwin' == sys.platform: | 181 | if 'http_proxy' in env and 'darwin' == sys.platform: |
| 263 | s = "'http.proxy=%s'" % (env['http_proxy'],) | 182 | s = "'http.proxy=%s'" % (env['http_proxy'],) |
| @@ -340,7 +259,7 @@ class GitCommand(object): | |||
| 340 | raise GitError('%s: %s' % (command[1], e)) | 259 | raise GitError('%s: %s' % (command[1], e)) |
| 341 | 260 | ||
| 342 | if ssh_proxy: | 261 | if ssh_proxy: |
| 343 | _add_ssh_client(p) | 262 | ssh.add_client(p) |
| 344 | 263 | ||
| 345 | self.process = p | 264 | self.process = p |
| 346 | if input: | 265 | if input: |
| @@ -352,7 +271,7 @@ class GitCommand(object): | |||
| 352 | try: | 271 | try: |
| 353 | self.stdout, self.stderr = p.communicate() | 272 | self.stdout, self.stderr = p.communicate() |
| 354 | finally: | 273 | finally: |
| 355 | _remove_ssh_client(p) | 274 | ssh.remove_client(p) |
| 356 | self.rc = p.wait() | 275 | self.rc = p.wait() |
| 357 | 276 | ||
| 358 | @staticmethod | 277 | @staticmethod |
diff --git a/git_config.py b/git_config.py index fcd0446c..1d8d1363 100644 --- a/git_config.py +++ b/git_config.py | |||
| @@ -18,25 +18,17 @@ from http.client import HTTPException | |||
| 18 | import json | 18 | import json |
| 19 | import os | 19 | import os |
| 20 | import re | 20 | import re |
| 21 | import signal | ||
| 22 | import ssl | 21 | import ssl |
| 23 | import subprocess | 22 | import subprocess |
| 24 | import sys | 23 | import sys |
| 25 | try: | ||
| 26 | import threading as _threading | ||
| 27 | except ImportError: | ||
| 28 | import dummy_threading as _threading | ||
| 29 | import time | ||
| 30 | import urllib.error | 24 | import urllib.error |
| 31 | import urllib.request | 25 | import urllib.request |
| 32 | 26 | ||
| 33 | from error import GitError, UploadError | 27 | from error import GitError, UploadError |
| 34 | import platform_utils | 28 | import platform_utils |
| 35 | from repo_trace import Trace | 29 | from repo_trace import Trace |
| 36 | 30 | import ssh | |
| 37 | from git_command import GitCommand | 31 | from git_command import GitCommand |
| 38 | from git_command import ssh_sock | ||
| 39 | from git_command import terminate_ssh_clients | ||
| 40 | from git_refs import R_CHANGES, R_HEADS, R_TAGS | 32 | from git_refs import R_CHANGES, R_HEADS, R_TAGS |
| 41 | 33 | ||
| 42 | ID_RE = re.compile(r'^[0-9a-f]{40}$') | 34 | ID_RE = re.compile(r'^[0-9a-f]{40}$') |
| @@ -440,129 +432,6 @@ class RefSpec(object): | |||
| 440 | return s | 432 | return s |
| 441 | 433 | ||
| 442 | 434 | ||
| 443 | _master_processes = [] | ||
| 444 | _master_keys = set() | ||
| 445 | _ssh_master = True | ||
| 446 | _master_keys_lock = None | ||
| 447 | |||
| 448 | |||
| 449 | def init_ssh(): | ||
| 450 | """Should be called once at the start of repo to init ssh master handling. | ||
| 451 | |||
| 452 | At the moment, all we do is to create our lock. | ||
| 453 | """ | ||
| 454 | global _master_keys_lock | ||
| 455 | assert _master_keys_lock is None, "Should only call init_ssh once" | ||
| 456 | _master_keys_lock = _threading.Lock() | ||
| 457 | |||
| 458 | |||
| 459 | def _open_ssh(host, port=None): | ||
| 460 | global _ssh_master | ||
| 461 | |||
| 462 | # Bail before grabbing the lock if we already know that we aren't going to | ||
| 463 | # try creating new masters below. | ||
| 464 | if sys.platform in ('win32', 'cygwin'): | ||
| 465 | return False | ||
| 466 | |||
| 467 | # Acquire the lock. This is needed to prevent opening multiple masters for | ||
| 468 | # the same host when we're running "repo sync -jN" (for N > 1) _and_ the | ||
| 469 | # manifest <remote fetch="ssh://xyz"> specifies a different host from the | ||
| 470 | # one that was passed to repo init. | ||
| 471 | _master_keys_lock.acquire() | ||
| 472 | try: | ||
| 473 | |||
| 474 | # Check to see whether we already think that the master is running; if we | ||
| 475 | # think it's already running, return right away. | ||
| 476 | if port is not None: | ||
| 477 | key = '%s:%s' % (host, port) | ||
| 478 | else: | ||
| 479 | key = host | ||
| 480 | |||
| 481 | if key in _master_keys: | ||
| 482 | return True | ||
| 483 | |||
| 484 | if not _ssh_master or 'GIT_SSH' in os.environ: | ||
| 485 | # Failed earlier, so don't retry. | ||
| 486 | return False | ||
| 487 | |||
| 488 | # We will make two calls to ssh; this is the common part of both calls. | ||
| 489 | command_base = ['ssh', | ||
| 490 | '-o', 'ControlPath %s' % ssh_sock(), | ||
| 491 | host] | ||
| 492 | if port is not None: | ||
| 493 | command_base[1:1] = ['-p', str(port)] | ||
| 494 | |||
| 495 | # Since the key wasn't in _master_keys, we think that master isn't running. | ||
| 496 | # ...but before actually starting a master, we'll double-check. This can | ||
| 497 | # be important because we can't tell that that 'git@myhost.com' is the same | ||
| 498 | # as 'myhost.com' where "User git" is setup in the user's ~/.ssh/config file. | ||
| 499 | check_command = command_base + ['-O', 'check'] | ||
| 500 | try: | ||
| 501 | Trace(': %s', ' '.join(check_command)) | ||
| 502 | check_process = subprocess.Popen(check_command, | ||
| 503 | stdout=subprocess.PIPE, | ||
| 504 | stderr=subprocess.PIPE) | ||
| 505 | check_process.communicate() # read output, but ignore it... | ||
| 506 | isnt_running = check_process.wait() | ||
| 507 | |||
| 508 | if not isnt_running: | ||
| 509 | # Our double-check found that the master _was_ infact running. Add to | ||
| 510 | # the list of keys. | ||
| 511 | _master_keys.add(key) | ||
| 512 | return True | ||
| 513 | except Exception: | ||
| 514 | # Ignore excpetions. We we will fall back to the normal command and print | ||
| 515 | # to the log there. | ||
| 516 | pass | ||
| 517 | |||
| 518 | command = command_base[:1] + ['-M', '-N'] + command_base[1:] | ||
| 519 | try: | ||
| 520 | Trace(': %s', ' '.join(command)) | ||
| 521 | p = subprocess.Popen(command) | ||
| 522 | except Exception as e: | ||
| 523 | _ssh_master = False | ||
| 524 | print('\nwarn: cannot enable ssh control master for %s:%s\n%s' | ||
| 525 | % (host, port, str(e)), file=sys.stderr) | ||
| 526 | return False | ||
| 527 | |||
| 528 | time.sleep(1) | ||
| 529 | ssh_died = (p.poll() is not None) | ||
| 530 | if ssh_died: | ||
| 531 | return False | ||
| 532 | |||
| 533 | _master_processes.append(p) | ||
| 534 | _master_keys.add(key) | ||
| 535 | return True | ||
| 536 | finally: | ||
| 537 | _master_keys_lock.release() | ||
| 538 | |||
| 539 | |||
| 540 | def close_ssh(): | ||
| 541 | global _master_keys_lock | ||
| 542 | |||
| 543 | terminate_ssh_clients() | ||
| 544 | |||
| 545 | for p in _master_processes: | ||
| 546 | try: | ||
| 547 | os.kill(p.pid, signal.SIGTERM) | ||
| 548 | p.wait() | ||
| 549 | except OSError: | ||
| 550 | pass | ||
| 551 | del _master_processes[:] | ||
| 552 | _master_keys.clear() | ||
| 553 | |||
| 554 | d = ssh_sock(create=False) | ||
| 555 | if d: | ||
| 556 | try: | ||
| 557 | platform_utils.rmdir(os.path.dirname(d)) | ||
| 558 | except OSError: | ||
| 559 | pass | ||
| 560 | |||
| 561 | # We're done with the lock, so we can delete it. | ||
| 562 | _master_keys_lock = None | ||
| 563 | |||
| 564 | |||
| 565 | URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):') | ||
| 566 | URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/') | 435 | URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/') |
| 567 | 436 | ||
| 568 | 437 | ||
| @@ -614,27 +483,6 @@ def GetUrlCookieFile(url, quiet): | |||
| 614 | yield cookiefile, None | 483 | yield cookiefile, None |
| 615 | 484 | ||
| 616 | 485 | ||
| 617 | def _preconnect(url): | ||
| 618 | m = URI_ALL.match(url) | ||
| 619 | if m: | ||
| 620 | scheme = m.group(1) | ||
| 621 | host = m.group(2) | ||
| 622 | if ':' in host: | ||
| 623 | host, port = host.split(':') | ||
| 624 | else: | ||
| 625 | port = None | ||
| 626 | if scheme in ('ssh', 'git+ssh', 'ssh+git'): | ||
| 627 | return _open_ssh(host, port) | ||
| 628 | return False | ||
| 629 | |||
| 630 | m = URI_SCP.match(url) | ||
| 631 | if m: | ||
| 632 | host = m.group(1) | ||
| 633 | return _open_ssh(host) | ||
| 634 | |||
| 635 | return False | ||
| 636 | |||
| 637 | |||
| 638 | class Remote(object): | 486 | class Remote(object): |
| 639 | """Configuration options related to a remote. | 487 | """Configuration options related to a remote. |
| 640 | """ | 488 | """ |
| @@ -673,7 +521,7 @@ class Remote(object): | |||
| 673 | 521 | ||
| 674 | def PreConnectFetch(self): | 522 | def PreConnectFetch(self): |
| 675 | connectionUrl = self._InsteadOf() | 523 | connectionUrl = self._InsteadOf() |
| 676 | return _preconnect(connectionUrl) | 524 | return ssh.preconnect(connectionUrl) |
| 677 | 525 | ||
| 678 | def ReviewUrl(self, userEmail, validate_certs): | 526 | def ReviewUrl(self, userEmail, validate_certs): |
| 679 | if self._review_url is None: | 527 | if self._review_url is None: |
| @@ -39,7 +39,7 @@ from color import SetDefaultColoring | |||
| 39 | import event_log | 39 | import event_log |
| 40 | from repo_trace import SetTrace | 40 | from repo_trace import SetTrace |
| 41 | from git_command import user_agent | 41 | from git_command import user_agent |
| 42 | from git_config import init_ssh, close_ssh, RepoConfig | 42 | from git_config import RepoConfig |
| 43 | from git_trace2_event_log import EventLog | 43 | from git_trace2_event_log import EventLog |
| 44 | from command import InteractiveCommand | 44 | from command import InteractiveCommand |
| 45 | from command import MirrorSafeCommand | 45 | from command import MirrorSafeCommand |
| @@ -56,6 +56,7 @@ from error import RepoChangedException | |||
| 56 | import gitc_utils | 56 | import gitc_utils |
| 57 | from manifest_xml import GitcClient, RepoClient | 57 | from manifest_xml import GitcClient, RepoClient |
| 58 | from pager import RunPager, TerminatePager | 58 | from pager import RunPager, TerminatePager |
| 59 | import ssh | ||
| 59 | from wrapper import WrapperPath, Wrapper | 60 | from wrapper import WrapperPath, Wrapper |
| 60 | 61 | ||
| 61 | from subcmds import all_commands | 62 | from subcmds import all_commands |
| @@ -592,7 +593,7 @@ def _Main(argv): | |||
| 592 | repo = _Repo(opt.repodir) | 593 | repo = _Repo(opt.repodir) |
| 593 | try: | 594 | try: |
| 594 | try: | 595 | try: |
| 595 | init_ssh() | 596 | ssh.init() |
| 596 | init_http() | 597 | init_http() |
| 597 | name, gopts, argv = repo._ParseArgs(argv) | 598 | name, gopts, argv = repo._ParseArgs(argv) |
| 598 | run = lambda: repo._Run(name, gopts, argv) or 0 | 599 | run = lambda: repo._Run(name, gopts, argv) or 0 |
| @@ -604,7 +605,7 @@ def _Main(argv): | |||
| 604 | else: | 605 | else: |
| 605 | result = run() | 606 | result = run() |
| 606 | finally: | 607 | finally: |
| 607 | close_ssh() | 608 | ssh.close() |
| 608 | except KeyboardInterrupt: | 609 | except KeyboardInterrupt: |
| 609 | print('aborted by user', file=sys.stderr) | 610 | print('aborted by user', file=sys.stderr) |
| 610 | result = 1 | 611 | result = 1 |
| @@ -0,0 +1,257 @@ | |||
| 1 | # Copyright (C) 2008 The Android Open Source Project | ||
| 2 | # | ||
| 3 | # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 4 | # you may not use this file except in compliance with the License. | ||
| 5 | # You may obtain a copy of the License at | ||
| 6 | # | ||
| 7 | # http://www.apache.org/licenses/LICENSE-2.0 | ||
| 8 | # | ||
| 9 | # Unless required by applicable law or agreed to in writing, software | ||
| 10 | # distributed under the License is distributed on an "AS IS" BASIS, | ||
| 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 12 | # See the License for the specific language governing permissions and | ||
| 13 | # limitations under the License. | ||
| 14 | |||
| 15 | """Common SSH management logic.""" | ||
| 16 | |||
| 17 | import functools | ||
| 18 | import os | ||
| 19 | import re | ||
| 20 | import signal | ||
| 21 | import subprocess | ||
| 22 | import sys | ||
| 23 | import tempfile | ||
| 24 | try: | ||
| 25 | import threading as _threading | ||
| 26 | except ImportError: | ||
| 27 | import dummy_threading as _threading | ||
| 28 | import time | ||
| 29 | |||
| 30 | import platform_utils | ||
| 31 | from repo_trace import Trace | ||
| 32 | |||
| 33 | |||
| 34 | _ssh_proxy_path = None | ||
| 35 | _ssh_sock_path = None | ||
| 36 | _ssh_clients = [] | ||
| 37 | |||
| 38 | |||
| 39 | def _run_ssh_version(): | ||
| 40 | """run ssh -V to display the version number""" | ||
| 41 | return subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT).decode() | ||
| 42 | |||
| 43 | |||
| 44 | def _parse_ssh_version(ver_str=None): | ||
| 45 | """parse a ssh version string into a tuple""" | ||
| 46 | if ver_str is None: | ||
| 47 | ver_str = _run_ssh_version() | ||
| 48 | m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str) | ||
| 49 | if m: | ||
| 50 | return tuple(int(x) for x in m.group(1).split('.')) | ||
| 51 | else: | ||
| 52 | return () | ||
| 53 | |||
| 54 | |||
| 55 | @functools.lru_cache(maxsize=None) | ||
| 56 | def version(): | ||
| 57 | """return ssh version as a tuple""" | ||
| 58 | try: | ||
| 59 | return _parse_ssh_version() | ||
| 60 | except subprocess.CalledProcessError: | ||
| 61 | print('fatal: unable to detect ssh version', file=sys.stderr) | ||
| 62 | sys.exit(1) | ||
| 63 | |||
| 64 | |||
| 65 | def proxy(): | ||
| 66 | global _ssh_proxy_path | ||
| 67 | if _ssh_proxy_path is None: | ||
| 68 | _ssh_proxy_path = os.path.join( | ||
| 69 | os.path.dirname(__file__), | ||
| 70 | 'git_ssh') | ||
| 71 | return _ssh_proxy_path | ||
| 72 | |||
| 73 | |||
| 74 | def add_client(p): | ||
| 75 | _ssh_clients.append(p) | ||
| 76 | |||
| 77 | |||
| 78 | def remove_client(p): | ||
| 79 | try: | ||
| 80 | _ssh_clients.remove(p) | ||
| 81 | except ValueError: | ||
| 82 | pass | ||
| 83 | |||
| 84 | |||
| 85 | def _terminate_clients(): | ||
| 86 | global _ssh_clients | ||
| 87 | for p in _ssh_clients: | ||
| 88 | try: | ||
| 89 | os.kill(p.pid, signal.SIGTERM) | ||
| 90 | p.wait() | ||
| 91 | except OSError: | ||
| 92 | pass | ||
| 93 | _ssh_clients = [] | ||
| 94 | |||
| 95 | |||
| 96 | _master_processes = [] | ||
| 97 | _master_keys = set() | ||
| 98 | _ssh_master = True | ||
| 99 | _master_keys_lock = None | ||
| 100 | |||
| 101 | |||
| 102 | def init(): | ||
| 103 | """Should be called once at the start of repo to init ssh master handling. | ||
| 104 | |||
| 105 | At the moment, all we do is to create our lock. | ||
| 106 | """ | ||
| 107 | global _master_keys_lock | ||
| 108 | assert _master_keys_lock is None, "Should only call init once" | ||
| 109 | _master_keys_lock = _threading.Lock() | ||
| 110 | |||
| 111 | |||
| 112 | def _open_ssh(host, port=None): | ||
| 113 | global _ssh_master | ||
| 114 | |||
| 115 | # Bail before grabbing the lock if we already know that we aren't going to | ||
| 116 | # try creating new masters below. | ||
| 117 | if sys.platform in ('win32', 'cygwin'): | ||
| 118 | return False | ||
| 119 | |||
| 120 | # Acquire the lock. This is needed to prevent opening multiple masters for | ||
| 121 | # the same host when we're running "repo sync -jN" (for N > 1) _and_ the | ||
| 122 | # manifest <remote fetch="ssh://xyz"> specifies a different host from the | ||
| 123 | # one that was passed to repo init. | ||
| 124 | _master_keys_lock.acquire() | ||
| 125 | try: | ||
| 126 | |||
| 127 | # Check to see whether we already think that the master is running; if we | ||
| 128 | # think it's already running, return right away. | ||
| 129 | if port is not None: | ||
| 130 | key = '%s:%s' % (host, port) | ||
| 131 | else: | ||
| 132 | key = host | ||
| 133 | |||
| 134 | if key in _master_keys: | ||
| 135 | return True | ||
| 136 | |||
| 137 | if not _ssh_master or 'GIT_SSH' in os.environ: | ||
| 138 | # Failed earlier, so don't retry. | ||
| 139 | return False | ||
| 140 | |||
| 141 | # We will make two calls to ssh; this is the common part of both calls. | ||
| 142 | command_base = ['ssh', | ||
| 143 | '-o', 'ControlPath %s' % sock(), | ||
| 144 | host] | ||
| 145 | if port is not None: | ||
| 146 | command_base[1:1] = ['-p', str(port)] | ||
| 147 | |||
| 148 | # Since the key wasn't in _master_keys, we think that master isn't running. | ||
| 149 | # ...but before actually starting a master, we'll double-check. This can | ||
| 150 | # be important because we can't tell that that 'git@myhost.com' is the same | ||
| 151 | # as 'myhost.com' where "User git" is setup in the user's ~/.ssh/config file. | ||
| 152 | check_command = command_base + ['-O', 'check'] | ||
| 153 | try: | ||
| 154 | Trace(': %s', ' '.join(check_command)) | ||
| 155 | check_process = subprocess.Popen(check_command, | ||
| 156 | stdout=subprocess.PIPE, | ||
| 157 | stderr=subprocess.PIPE) | ||
| 158 | check_process.communicate() # read output, but ignore it... | ||
| 159 | isnt_running = check_process.wait() | ||
| 160 | |||
| 161 | if not isnt_running: | ||
| 162 | # Our double-check found that the master _was_ infact running. Add to | ||
| 163 | # the list of keys. | ||
| 164 | _master_keys.add(key) | ||
| 165 | return True | ||
| 166 | except Exception: | ||
| 167 | # Ignore excpetions. We we will fall back to the normal command and print | ||
| 168 | # to the log there. | ||
| 169 | pass | ||
| 170 | |||
| 171 | command = command_base[:1] + ['-M', '-N'] + command_base[1:] | ||
| 172 | try: | ||
| 173 | Trace(': %s', ' '.join(command)) | ||
| 174 | p = subprocess.Popen(command) | ||
| 175 | except Exception as e: | ||
| 176 | _ssh_master = False | ||
| 177 | print('\nwarn: cannot enable ssh control master for %s:%s\n%s' | ||
| 178 | % (host, port, str(e)), file=sys.stderr) | ||
| 179 | return False | ||
| 180 | |||
| 181 | time.sleep(1) | ||
| 182 | ssh_died = (p.poll() is not None) | ||
| 183 | if ssh_died: | ||
| 184 | return False | ||
| 185 | |||
| 186 | _master_processes.append(p) | ||
| 187 | _master_keys.add(key) | ||
| 188 | return True | ||
| 189 | finally: | ||
| 190 | _master_keys_lock.release() | ||
| 191 | |||
| 192 | |||
| 193 | def close(): | ||
| 194 | global _master_keys_lock | ||
| 195 | |||
| 196 | _terminate_clients() | ||
| 197 | |||
| 198 | for p in _master_processes: | ||
| 199 | try: | ||
| 200 | os.kill(p.pid, signal.SIGTERM) | ||
| 201 | p.wait() | ||
| 202 | except OSError: | ||
| 203 | pass | ||
| 204 | del _master_processes[:] | ||
| 205 | _master_keys.clear() | ||
| 206 | |||
| 207 | d = sock(create=False) | ||
| 208 | if d: | ||
| 209 | try: | ||
| 210 | platform_utils.rmdir(os.path.dirname(d)) | ||
| 211 | except OSError: | ||
| 212 | pass | ||
| 213 | |||
| 214 | # We're done with the lock, so we can delete it. | ||
| 215 | _master_keys_lock = None | ||
| 216 | |||
| 217 | |||
| 218 | URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):') | ||
| 219 | URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/') | ||
| 220 | |||
| 221 | |||
| 222 | def preconnect(url): | ||
| 223 | m = URI_ALL.match(url) | ||
| 224 | if m: | ||
| 225 | scheme = m.group(1) | ||
| 226 | host = m.group(2) | ||
| 227 | if ':' in host: | ||
| 228 | host, port = host.split(':') | ||
| 229 | else: | ||
| 230 | port = None | ||
| 231 | if scheme in ('ssh', 'git+ssh', 'ssh+git'): | ||
| 232 | return _open_ssh(host, port) | ||
| 233 | return False | ||
| 234 | |||
| 235 | m = URI_SCP.match(url) | ||
| 236 | if m: | ||
| 237 | host = m.group(1) | ||
| 238 | return _open_ssh(host) | ||
| 239 | |||
| 240 | return False | ||
| 241 | |||
| 242 | def sock(create=True): | ||
| 243 | global _ssh_sock_path | ||
| 244 | if _ssh_sock_path is None: | ||
| 245 | if not create: | ||
| 246 | return None | ||
| 247 | tmp_dir = '/tmp' | ||
| 248 | if not os.path.exists(tmp_dir): | ||
| 249 | tmp_dir = tempfile.gettempdir() | ||
| 250 | if version() < (6, 7): | ||
| 251 | tokens = '%r@%h:%p' | ||
| 252 | else: | ||
| 253 | tokens = '%C' # hash of %l%h%p%r | ||
| 254 | _ssh_sock_path = os.path.join( | ||
| 255 | tempfile.mkdtemp('', 'ssh-', tmp_dir), | ||
| 256 | 'master-' + tokens) | ||
| 257 | return _ssh_sock_path | ||
diff --git a/tests/test_git_command.py b/tests/test_git_command.py index 76c092f4..93300a6f 100644 --- a/tests/test_git_command.py +++ b/tests/test_git_command.py | |||
| @@ -26,38 +26,6 @@ import git_command | |||
| 26 | import wrapper | 26 | import wrapper |
| 27 | 27 | ||
| 28 | 28 | ||
| 29 | class SSHUnitTest(unittest.TestCase): | ||
| 30 | """Tests the ssh functions.""" | ||
| 31 | |||
| 32 | def test_parse_ssh_version(self): | ||
| 33 | """Check parse_ssh_version() handling.""" | ||
| 34 | ver = git_command._parse_ssh_version('Unknown\n') | ||
| 35 | self.assertEqual(ver, ()) | ||
| 36 | ver = git_command._parse_ssh_version('OpenSSH_1.0\n') | ||
| 37 | self.assertEqual(ver, (1, 0)) | ||
| 38 | ver = git_command._parse_ssh_version('OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n') | ||
| 39 | self.assertEqual(ver, (6, 6, 1)) | ||
| 40 | ver = git_command._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n') | ||
| 41 | self.assertEqual(ver, (7, 6)) | ||
| 42 | |||
| 43 | def test_ssh_version(self): | ||
| 44 | """Check ssh_version() handling.""" | ||
| 45 | with mock.patch('git_command._run_ssh_version', return_value='OpenSSH_1.2\n'): | ||
| 46 | self.assertEqual(git_command.ssh_version(), (1, 2)) | ||
| 47 | |||
| 48 | def test_ssh_sock(self): | ||
| 49 | """Check ssh_sock() function.""" | ||
| 50 | with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'): | ||
| 51 | # old ssh version uses port | ||
| 52 | with mock.patch('git_command.ssh_version', return_value=(6, 6)): | ||
| 53 | self.assertTrue(git_command.ssh_sock().endswith('%p')) | ||
| 54 | git_command._ssh_sock_path = None | ||
| 55 | # new ssh version uses hash | ||
| 56 | with mock.patch('git_command.ssh_version', return_value=(6, 7)): | ||
| 57 | self.assertTrue(git_command.ssh_sock().endswith('%C')) | ||
| 58 | git_command._ssh_sock_path = None | ||
| 59 | |||
| 60 | |||
| 61 | class GitCallUnitTest(unittest.TestCase): | 29 | class GitCallUnitTest(unittest.TestCase): |
| 62 | """Tests the _GitCall class (via git_command.git).""" | 30 | """Tests the _GitCall class (via git_command.git).""" |
| 63 | 31 | ||
diff --git a/tests/test_ssh.py b/tests/test_ssh.py new file mode 100644 index 00000000..5a4f27e4 --- /dev/null +++ b/tests/test_ssh.py | |||
| @@ -0,0 +1,52 @@ | |||
| 1 | # Copyright 2019 The Android Open Source Project | ||
| 2 | # | ||
| 3 | # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 4 | # you may not use this file except in compliance with the License. | ||
| 5 | # You may obtain a copy of the License at | ||
| 6 | # | ||
| 7 | # http://www.apache.org/licenses/LICENSE-2.0 | ||
| 8 | # | ||
| 9 | # Unless required by applicable law or agreed to in writing, software | ||
| 10 | # distributed under the License is distributed on an "AS IS" BASIS, | ||
| 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 12 | # See the License for the specific language governing permissions and | ||
| 13 | # limitations under the License. | ||
| 14 | |||
| 15 | """Unittests for the ssh.py module.""" | ||
| 16 | |||
| 17 | import unittest | ||
| 18 | from unittest import mock | ||
| 19 | |||
| 20 | import ssh | ||
| 21 | |||
| 22 | |||
| 23 | class SshTests(unittest.TestCase): | ||
| 24 | """Tests the ssh functions.""" | ||
| 25 | |||
| 26 | def test_parse_ssh_version(self): | ||
| 27 | """Check _parse_ssh_version() handling.""" | ||
| 28 | ver = ssh._parse_ssh_version('Unknown\n') | ||
| 29 | self.assertEqual(ver, ()) | ||
| 30 | ver = ssh._parse_ssh_version('OpenSSH_1.0\n') | ||
| 31 | self.assertEqual(ver, (1, 0)) | ||
| 32 | ver = ssh._parse_ssh_version('OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n') | ||
| 33 | self.assertEqual(ver, (6, 6, 1)) | ||
| 34 | ver = ssh._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n') | ||
| 35 | self.assertEqual(ver, (7, 6)) | ||
| 36 | |||
| 37 | def test_version(self): | ||
| 38 | """Check version() handling.""" | ||
| 39 | with mock.patch('ssh._run_ssh_version', return_value='OpenSSH_1.2\n'): | ||
| 40 | self.assertEqual(ssh.version(), (1, 2)) | ||
| 41 | |||
| 42 | def test_ssh_sock(self): | ||
| 43 | """Check sock() function.""" | ||
| 44 | with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'): | ||
| 45 | # old ssh version uses port | ||
| 46 | with mock.patch('ssh.version', return_value=(6, 6)): | ||
| 47 | self.assertTrue(ssh.sock().endswith('%p')) | ||
| 48 | ssh._ssh_sock_path = None | ||
| 49 | # new ssh version uses hash | ||
| 50 | with mock.patch('ssh.version', return_value=(6, 7)): | ||
| 51 | self.assertTrue(ssh.sock().endswith('%C')) | ||
| 52 | ssh._ssh_sock_path = None | ||
