#!/usr/bin/python

# osh
# Copyright (C) 2005 Jack Orenstein <jao@geophile.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.


# stdin carries 2 or 3 inputs:
# 1) verbosity flag
# 2) pipeline to be executed
# 3) optional: kill signal
#
# The main thread processes stdin. Command execution takes place
# in a separate thread (_PipelineRunner). Termination occurs in
# one of two ways:
# 1) kill signal arrives: The kill is applied to this process and all
#    descendents.
# 2) Command terminates (normally or with an exception): kill self and
#    all descendents anyway (but from _PipelineRunner). This is the
#    only way to stop waiting for the kill signal if the client
#    doesn't end first.
#
# The _PipelineRunner thread is never joined! The reason is that the
# process will be killed by one thread or the other.

import cPickle
import os
import sys
import threading
import traceback

import osh.error
import osh.core
import osh.process

# Process id of this process; _kill_self_and_descendents uses it to find
# the process tree to kill.
pid = os.getpid()

# Set to True by _shutdown once stdout and stderr have been closed, so
# that the streams are not closed a second time.
closed_streams = False

# When true, trace() appends diagnostic lines to /tmp/trace.
TRACE = True
# Trace file object; opened lazily by the first trace() call.
_tracefile = None

def trace(line):
    """Write line to the trace file and flush it immediately.

    Does nothing unless TRACE is true. The trace file (/tmp/trace) is
    opened for writing on the first call and then kept open.
    """
    global _tracefile
    if not TRACE:
        return
    if _tracefile is None:
        _tracefile = open('/tmp/trace', 'w')
    print >>_tracefile, line
    _tracefile.flush()

def flush_trace():
    """Flush the trace file, if tracing is enabled and the file is open."""
    if TRACE and _tracefile:
        _tracefile.flush()

def _kill_self_and_descendents(kill_signal = None):
    """Kill every descendent of this process, then this process itself.

    kill_signal is handed unchanged to the kill() method of each
    osh.process.Process; how None is interpreted is up to that method.

    This is best effort: a failure while killing is written to the trace
    file (message and traceback) and is not re-raised.
    """
    trace('>>> In _kill_self_and_descendents for process %s' % pid)
    try:
        this_process = osh.process.Process(pid)
        # Descendents go first; once this process is killed nothing
        # after that point can be relied on to run.
        for descendent in this_process.descendents():
            trace('>>> killing %s' % descendent.pid())
            descendent.kill(kill_signal)
        trace('>>> killing %s' % this_process.pid())
        this_process.kill(kill_signal)
    except:
        # Deliberately catch everything: this is the last step before the
        # process dies, so all that can be done is record the failure.
        # A bare except binds no name, so fetch the active exception
        # explicitly rather than referring to an undefined variable.
        error = sys.exc_info()[1]
        trace('>>> exception while killing self: %s' % str(error))
        traceback.print_exc(file = _tracefile)
        flush_trace()

def _shutdown():
    """Flush and then close stdout and stderr.

    The closed_streams flag is set once both streams have been closed;
    calls made after that return without touching the streams.
    """
    global closed_streams
    if closed_streams:
        return
    trace('Closing stdout and stderr')
    streams = (sys.stdout, sys.stderr)
    # Flush both streams before closing either of them.
    for stream in streams:
        stream.flush()
    for stream in streams:
        stream.close()
    closed_streams = True

class _Pickler(osh.core.Op):
    """Op that pickles every object it receives onto stdout.

    Creating an instance also installs an osh exception handler that
    pickles an osh.error.Error describing the failure onto the same
    stream.
    """

    # cPickle.Pickler writing to sys.stdout; assigned in __init__.
    _output = None

    def __init__(self):
        osh.core.Op.__init__(self, '', (0, 0))
        pickler = cPickle.Pickler(sys.stdout)
        self._output = pickler

        def _remoteosh_exception_handler(exception, op, input, host = None):
            # Report the failure through the pickle stream instead of
            # raising it.
            trace('Handling exception on %s(%s): %s' % (op, input, exception))
            error = osh.error.Error(str(op), input, exception)
            self._output.dump(error)

        osh.error.set_exception_handler(_remoteosh_exception_handler)

    def __str__(self):
        return '_Pickler#%s' % self.id()

    def setup(self):
        """Nothing to set up."""
        pass

    def receive(self, object):
        """Trace each element of object, then pickle object to stdout."""
        for index, element in enumerate(object):
            trace('object[%s]: (%s) %s' % (index, element.__class__, element))
        self._output.dump(object)

    def receive_complete(self):
        """Intentionally does nothing.

        If a command throws an exception but does send_complete in a
        finally block, this method can run before the exception handler
        is invoked. Calling _shutdown() here would close stdout, and the
        dump of the Error to stdout would then be missed.
        """
        pass

class _PipelineRunner(threading.Thread):

    _pipeline = None

    def __init__(self, pipeline):
        threading.Thread.__init__(self)
        self._pipeline = pipeline

    def run(self):
        try:
            try:
                self._pipeline.append_op(_Pickler())
                trace(('pipeline: %s') % self._pipeline.dump())
                self._pipeline.setup()
                self._pipeline.execute()
                trace('done')
            except Exception, e:
                trace('Caught exception during execution: %s' % str(e))
                traceback.print_exc(file = _tracefile)
                flush_trace()
                osh.error.exception_handler(e, None, None)
        finally:
            _shutdown()
            trace('About to kill self and descendents')
            _kill_self_and_descendents()

# Main thread: read the pickled inputs from stdin, start the pipeline in
# a _PipelineRunner thread, then wait for an optional kill signal.
#
# NOTE(review): everything arriving on stdin is unpickled, and unpickling
# can execute arbitrary code. This presumably relies on stdin being fed
# only by a trusted osh client -- confirm.
input = cPickle.Unpickler(sys.stdin)
try:
    # First input: verbosity flag.
    osh.core.verbosity = input.load()
    trace('verbosity: %s' % osh.core.verbosity)
    # osh_usage controls error handling. On remote side (i.e., here),
    # do CLI error handling -- write to e stream. Caller will deal with it.
    osh.core.osh_usage = osh.core.USAGE_CLI
    # An optional first command-line argument sets the default db profile.
    if len(sys.argv) > 1:
        osh.core.default_db_profile = sys.argv[1]
    # Second input: the pipeline to be executed.
    pipeline = input.load()
    pipeline_runner = _PipelineRunner(pipeline)
    # The runner thread is started but never joined here.
    pipeline_runner.start()
    # Wait for kill signal that may never come
    try:
        # Third, optional input: load() blocks until a kill signal
        # arrives or stdin reaches end of file.
        kill_signal = input.load()
        trace('Received kill signal %s' % kill_signal)
        _kill_self_and_descendents(kill_signal)
    except EOFError, e:
        # stdin ended without a kill signal: record it, then kill with
        # signal 9 (presumably SIGKILL -- confirm against
        # osh.process.Process.kill).
        trace('EOFError waiting for kill signal')
        trace(str(e))
        traceback.print_exc(file = _tracefile)
        flush_trace()
        _kill_self_and_descendents(9)
except Exception, e:
    # Any other failure in the main thread is only written to the trace
    # file; it is not re-raised.
    trace(str(e))
    traceback.print_exc(file = _tracefile)
    flush_trace()
