#!/usr/bin/python

# osh
# Copyright (C) 2005 Jack Orenstein <jao@geophile.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

"""oshtestssh -c CLUSTER
"""

import getopt
import sys
import threading;

from osh.config import *
import osh.spawn
import osh.util
import osh.cluster

# Short local aliases for the osh.spawn classes used by Tester.run.
SpawnSSH = osh.spawn.SpawnSSH
LineOutputConsumer = osh.spawn.LineOutputConsumer

def usage():
    """Print the module docstring (the usage synopsis) and exit with status 1."""
    # A single parenthesized argument prints identically under the Python 2
    # print statement and the Python 3 print function.
    print(__doc__)
    # Same as sys.exit(1): sys.exit just raises SystemExit with its argument.
    raise SystemExit(1)

def print_stdout(host, line):
    """Echo one line of a host's remote stdout, tagged with the host name.

    No newline is appended; the caller-supplied line is written as is.
    """
    template = '%(host)s stdout: %(line)s'
    out = sys.stdout
    out.write(template % {'host': host, 'line': line})

def print_stderr(host, line):
    """Echo one line of a host's remote stderr, tagged with the host name.

    NOTE: this deliberately writes to sys.stdout, not sys.stderr; the
    'stderr:' tag marks the origin (presumably so both remote streams end up
    interleaved in a single local stream). No newline is appended.
    """
    template = '%(host)s stderr: %(line)s'
    out = sys.stdout
    out.write(template % {'host': host, 'line': line})

class Tester(threading.Thread):
    """Thread that runs ``echo hello`` on one host through SpawnSSH.

    Every line the command produces is echoed locally, tagged with the
    host's name, via print_stdout / print_stderr.
    """

    # Class-level defaults; __init__ replaces both on every instance.
    _user = None
    _host = None

    def __init__(self, user, host):
        # user: login name handed to SpawnSSH.
        # host: object exposing .address (connection target) and .name
        #       (label used when echoing output).
        threading.Thread.__init__(self)
        self._user = user
        self._host = host

    def run(self):
        """Spawn the remote ``echo hello`` and run it on this thread."""
        user = self._user
        address = self._host.address

        # The host name is looked up each time a line arrives, not up front.
        def on_stdout(line):
            print_stdout(self._host.name, line)

        def on_stderr(line):
            print_stderr(self._host.name, line)

        # NOTE(review): the None argument's meaning is defined by
        # osh.spawn.SpawnSSH (looks like "no input") - confirm there.
        spawn = SpawnSSH(user,
                         address,
                         "echo hello",
                         None,
                         LineOutputConsumer(on_stdout),
                         LineOutputConsumer(on_stderr))
        spawn.run()

def test_connection(cluster):
    """Start one Tester per host in ``cluster`` and wait for all to finish.

    Each thread is started as soon as it is created, so all hosts are
    contacted concurrently; the function returns once every thread has ended.
    """
    def launch(host):
        # Create and immediately start the worker for a single host.
        worker = Tester(cluster.user, host)
        worker.start()
        return worker

    workers = [launch(host) for host in cluster.hosts]

    for worker in workers:
        # Poll with a timed join instead of a bare join(); presumably so that
        # Ctrl-C can still interrupt the wait (an untimed join() blocks signal
        # delivery on older Pythons).
        while worker.isAlive():
            worker.join(1.0)

# Command-line entry point: parse "-c CLUSTER", look up that cluster and run
# the connectivity test against each of its hosts.
try:
    options, args = getopt.getopt(sys.argv[1:], 'c:')
except getopt.GetoptError:
    # Unknown option, or -c given without a value: show the usage synopsis
    # instead of dying with a traceback. usage() exits, so options/args are
    # never used unbound.
    usage()
if args:
    # No positional arguments are accepted.
    usage()
cluster_name = None
for option, value in options:
    if option == '-c':
        cluster_name = value
if not cluster_name:
    # -c missing, or given an empty string.
    usage()
cluster = osh.cluster.cluster_named(cluster_name)
test_connection(cluster)
