Package osh :: Package command :: Module fork
[frames] | [no frames]

Source Code for Module osh.command.fork

  1  # osh 
  2  # Copyright (C) Jack Orenstein <jao@geophile.com> 
  3  # 
  4  # This program is free software; you can redistribute it and/or modify 
  5  # it under the terms of the GNU General Public License as published by 
  6  # the Free Software Foundation; either version 2 of the License, or 
  7  # (at your option) any later version. 
  8  # 
  9  # This program is distributed in the hope that it will be useful, 
 10  # but WITHOUT ANY WARRANTY; without even the implied warranty of 
 11  # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
 12  # GNU General Public License for more details. 
 13  # 
 14  # You should have received a copy of the GNU General Public License 
 15  # along with this program; if not, write to the Free Software 
 16  # Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. 
 17   
 18  """For API usage only, (for CLI use C{osh @FORK [ ... ]} syntax instead.) 
 19  """ 
 20   
 21  import types 
 22   
 23  import osh.args 
 24  import osh.cluster 
 25  import osh.core 
 26  import osh.error 
 27  import osh.function 
 28  import osh.oshthread 
 29  import osh.spawn 
 30  import osh.util 
 31  import merge 
 32   
# Convenience aliases: shorten the fully-qualified names used repeatedly below.
LineOutputConsumer = osh.spawn.LineOutputConsumer
ObjectInputProvider = osh.spawn.ObjectInputProvider
ObjectOutputConsumer = osh.spawn.ObjectOutputConsumer
Spawn = osh.spawn.Spawn
Option = osh.args.Option
# Module-private factory for turning CLI strings / callables into osh functions.
create_function = osh.function._create_function
 39   
# CLI
def _fork():
    """Factory used by the CLI parser: returns a fresh _Fork operator."""
    return _Fork()
# API
def fork(threadgen, command, merge_key = None):
    """Creates threads and executes C{command} on each. The number of threads is determined
    by C{threadgen}. If C{threadgen} is an integer, then the specified number of threads is
    created, each labelled 0 through C{threadgen} - 1. If C{threadgen} is a sequence, one
    thread is created per element, labelled with that element. If C{threadgen} is a function,
    it is evaluated and must yield an integer or sequence, handled as above. If C{threadgen}
    is a cluster specification, the command runs on each specified host and the thread label
    identifies the host (type C{osh.cluster.Host}). If C{merge_key} is specified, each
    thread's input must already be ordered by it; the per-thread sequences are then merged
    into a single sequence using C{merge_key}.
    """
    import osh.apiparser
    # A single op is treated as a one-op pipeline.
    if isinstance(command, osh.core.Op):
        command = [command]
    pipeline = osh.apiparser._sequence_op(command)
    op = _Fork()
    # merge_key is passed positionally only when present.
    arglist = [threadgen, pipeline]
    if merge_key:
        arglist.append(merge_key)
    return op.process_args(*arglist)
67
class _Fork(osh.core.Generator):
    """Runs a copy of a pipeline on each of several threads, labelling each
    thread's output with its thread state (label or host), and merging the
    per-thread streams into a single output stream.
    """

    # state

    _threads = None            # list of osh.oshthread._OshThread, built in setup_threads
    _pipeline = None           # pipeline template; each thread runs its own copy
    _merge_key = None          # optional merge key passed through to merge.merge
    _function_store = None     # FunctionStore used to make the pipeline copyable
    _cluster_required = None   # True when invoked via the remote API (see _set_cluster_required)

    # object interface

    def __init__(self):
        # 2-3 args: threadgen, pipeline, and optionally merge_key.
        osh.core.Generator.__init__(self, '', (2, 3))
        self._function_store = FunctionStore()
        self._cluster_required = False

    # BaseOp interface

    def doc(self):
        # Module docstring doubles as the operator's documentation.
        return __doc__

    def setup(self):
        """Resolve thread labels, finalize the pipeline template, copy it per
        thread, and wire up shared command state across the copies."""
        args = self.args()
        threadgen = args.next()
        self._pipeline = args.next()
        self._merge_key = args.next()
        cluster, thread_ids = self.thread_ids(threadgen)
        self.setup_pipeline(cluster)
        self.setup_threads(thread_ids)
        self.setup_shared_state()

    def receive_complete(self):
        # Propagate end-of-input to every thread's pipeline copy.
        for thread in self._threads:
            thread.pipeline.receive_complete()

    # generator interface

    def execute(self):
        """Start all threads, then join each one, reporting any exception a
        thread terminated with."""
        for thread in self._threads:
            thread.pipeline.setup()
            # All copies feed the fork's own receiver; merge ops coordinate ordering.
            thread.pipeline.set_receiver(self._receiver)
            thread.start()
        for thread in self._threads:
            # Join with a timeout so the main thread stays interruptible (Python 2 idiom).
            while thread.isAlive():
                thread.join(0.1)
            thread_termination = thread.terminating_exception
            if thread_termination:
                osh.error.exception_handler(thread_termination, self, None, thread)

    # For use by this package

    def _set_cluster_required(self, required):
        # Set by the remote command so that a non-cluster threadgen is rejected.
        self._cluster_required = required

    # For use by this class

    def thread_ids(self, threadgen, already_evaled = False):
        """Return (cluster, thread_ids) derived from threadgen, which may be a
        list/tuple, an int, a digit string, a function, or a cluster name
        (optionally 'name:pattern'). cluster is None unless threadgen named a
        configured cluster. Calls self.usage() on anything unresolvable."""
        threadgen_type = type(threadgen)
        try:
            cluster = None
            thread_ids = None
            if threadgen_type in (list, tuple):
                thread_ids = threadgen
            elif isinstance(threadgen, int):
                thread_ids = range(threadgen)
            elif threadgen.isdigit():
                thread_ids = range(int(threadgen))
            elif threadgen_type is types.FunctionType:
                if already_evaled:
                    # A function that evaluated to another function: give up.
                    self.usage()
                else:
                    # Evaluate once, then re-dispatch on the result.
                    cluster, thread_ids = self.thread_ids(create_function(threadgen)(), True)
            else:
                # String, which might be a CLI function invocation
                cluster_name, pattern = (threadgen.split(':') + [None])[:2]
                cluster = osh.cluster.cluster_named(cluster_name, pattern)
                if cluster:
                    thread_ids = cluster.hosts
                else:
                    evaled_threadgen = create_function(threadgen)()
                    cluster, thread_ids = self.thread_ids(evaled_threadgen, True)
            if self._cluster_required and cluster is None:
                # API invoked remote but did not identify a cluster
                import remote
                self.usage(remote.__doc__)
            if thread_ids is None:
                self.usage()
            return cluster, thread_ids
        except:
            # NOTE(review): bare except swallows every failure (including
            # exceptions raised by self.usage() above) and reports a generic
            # usage error; narrowing this would aid debugging — TODO confirm
            # that usage() raises and that re-catching it here is intended.
            self.usage()

    def setup_pipeline(self, cluster):
        """Finalize the pipeline template: wrap it in a _Remote op when it must
        run on cluster hosts, then append thread-state labelling and merging."""
        if cluster and not self._pipeline.run_local():
            # Replace the template with a one-op pipeline that ships the
            # original pipeline to each remote host.
            remote_op = _Remote()
            remote_op.process_args(self._pipeline)
            self._pipeline = osh.core.Pipeline()
            self._pipeline.append_op(remote_op)
        # NOTE(review): source indentation was lost in extraction; these two
        # appends are placed outside the if on the reading that labelling and
        # merging apply to local forks as well — confirm against upstream osh.
        self._pipeline.append_op(_AttachThreadState())
        self._pipeline.append_op(merge.merge(self._merge_key))

    def setup_threads(self, thread_ids):
        """Create one thread per thread id, each owning its own pipeline copy."""
        pipeline_copier = _PipelineCopier(self)
        # Use FunctionStore to hide functions during pipeline copying
        self._function_store.hide_functions(self._pipeline)
        threads = []
        for thread_id in thread_ids:
            pipeline_copy = pipeline_copier.pipeline(thread_id)
            thread = osh.oshthread._OshThread(self, thread_id, pipeline_copy)
            threads.append(thread)
        # Template gets its functions back once all copies are made.
        self._function_store.restore_functions(self._pipeline)
        self._threads = threads

    def setup_shared_state(self):
        # Set up shared state for each command in the pipeline: Traverse the pipelines
        # in parallel. Use self._pipeline to allocate the shared state, and then pass the
        # state to each copy.
        pipeline_copy_iterators = [thread.pipeline.ops() for thread in self._threads]
        for pipeline_template_op in self._pipeline.ops():
            command_state = pipeline_template_op.create_command_state(self._threads)
            for pipeline_copy_iterator in pipeline_copy_iterators:
                pipeline_copy_op = pipeline_copy_iterator.next()
                pipeline_copy_op.set_command_state(command_state)
193
class _PipelineCopier(object):
    """Produces per-thread copies of a fork's pipeline template.

    The template's functions are assumed to have been hidden (replaced by
    FunctionReference values) before copying; each copy gets real functions
    restored immediately after cloning.
    """

    _fork = None

    def __init__(self, fork):
        self._fork = fork

    def pipeline(self, thread_state):
        """Clone the fork's pipeline template and restore its functions."""
        duplicate = osh.util.clone(self._fork._pipeline)
        self._fork._function_store.restore_functions(duplicate)
        return duplicate
205
class _AttachThreadState(osh.core.Op):
    """Op that prefixes every tuple flowing through it with the owning
    thread's state (its label or host), so downstream consumers can tell
    which fork thread produced each object."""

    # 1-tuple holding this thread's state; computed once in setup.
    _thread_state = None

    def __init__(self):
        # Takes no arguments.
        osh.core.Op.__init__(self, '', (0, 0))

    def setup(self):
        self._thread_state = (self.thread_state,)

    def receive(self, object):
        prefix = self._thread_state
        # Lists are converted so tuple concatenation below is well-defined.
        if type(object) is list:
            object = tuple(object)
        self.send(prefix + object)
220 221 # osh needs to copy pipelines to support forks. 222 # 1. Pickling: doesn't handle functions. 223 # 2. Marshaling: doesn't handle recursive structures. Pipelines are recursive due to BaseOp.parent. 224 # 3. Add a clone method to the BaseOp interface. Lots of work to handle recursion. 225 # This implementation is a combination of 1 and 3: BaseOp.replace_function_by_reference 226 # replaces functions by integer references to functions. BaseOp.restore_function does the 227 # inverse. A pipeline is copied by: 228 # - Apply replace_function_by_reference recursively to the input pipeline. 229 # - Copy the pipeline. 230 # - Apply restore_function to the copy. 231
class FunctionReference(int):
    """An int subclass marking an index into a FunctionStore; distinguishable
    from ordinary ints so restoration knows what to replace."""
    pass

class FunctionStore(object):
    """Swaps functions out of a pipeline (replacing each with a
    FunctionReference index) so the pipeline can be copied, then swaps them
    back in. See the module comments above on why copying requires this."""

    _functions = None

    def __init__(self):
        self._functions = []

    # For use by this module

    def hide_functions(self, pipeline):
        """Replace every function in pipeline with a FunctionReference."""
        pipeline.replace_function_by_reference(self)

    def restore_functions(self, pipeline):
        """Replace every FunctionReference in pipeline with its function."""
        pipeline.restore_function(self)

    # For use by BaseOp subclasses in hiding and restoring functions

    def function_to_reference(self, x):
        """Stash x if it is a function, returning its reference; otherwise
        return x unchanged."""
        if type(x) is not types.FunctionType:
            return x
        ref = FunctionReference(len(self._functions))
        self._functions.append(x)
        return ref

    def reference_to_function(self, x):
        """Inverse of function_to_reference: resolve a reference back to the
        stashed function; pass anything else through unchanged."""
        if type(x) is FunctionReference:
            return self._functions[x]
        return x
# Remote execution

# Name of the osh executable invoked on each remote host via ssh.
_REMOTE_EXECUTABLE = 'remoteosh'
def _dump(stream, object):
    """Serialize object using the stream's own dump method (helper for the
    ObjectInputProvider lambda in _Remote.execute)."""
    stream.dump(object)
272
def _consume_remote_stdout(consumer, threadid, object):
    """Forward an object received from a remote process's stdout. Exceptions
    that were pickled on the remote side are recreated and routed to the
    exception handler instead of being forwarded."""
    if not isinstance(object, osh.error.PickleableException):
        consumer.send(object)
        return
    exception = object.recreate_exception()
    osh.error.exception_handler(exception, object.command_description(), object.input(), threadid)
279
def _consume_remote_stderr(consumer, threadid, line):
    """Route a remote stderr line to the stderr handler, suppressing one
    known-spurious message."""
    # UGLY HACK: remoteosh can occasionally return "[Errno 9] Bad file descriptor" on stderr.
    # I think this is because of io to a process stream whose process has completed.
    # I haven't had luck in tracking this down and fixing the problem for real, so
    # this is a grotesque workaround.
    if '[Errno 9] Bad file descriptor' in line:
        return
    osh.error.stderr_handler(line, consumer, None, threadid)
287
class _Remote(osh.core.Generator):
    """Runs a pipeline on a remote host: spawns remoteosh over ssh, streams
    the pipeline to it, and consumes the remote process's stdout/stderr.
    The target host is this op's thread state (an osh.cluster.Host)."""

    # state

    _pipeline = None   # pipeline shipped to the remote process in setup/execute

    # object interface

    def __init__(self):
        # Exactly one argument: the pipeline to execute remotely.
        osh.core.Generator.__init__(self, '', (1, 1))

    # BaseOp interface

    def doc(self):
        return __doc__

    def setup(self):
        self._pipeline = self.args().next()

    # generator interface

    def execute(self):
        """Spawn the remote process, feeding it (verbosity, pipeline, thread
        state) and routing its stdout/stderr through the consumer helpers.
        Re-raises any exception that terminated the spawned process."""
        host = self.thread_state
        process = Spawn(
            self._remote_command(host.address, host.user, host.identity, host.db_profile),
            ObjectInputProvider(lambda stream, object: _dump(stream, object),
                                [osh.core.verbosity, self._pipeline, self.thread_state]),
            ObjectOutputConsumer(lambda object: _consume_remote_stdout(self, host, object)),
            LineOutputConsumer(lambda line: _consume_remote_stderr(self, host, line)))
        process.run()
        if process.terminating_exception():
            raise process.terminating_exception()

    # for use by this class

    def _remote_command(self, host, user, identity, db_profile):
        """Build the ssh command line that launches remoteosh on host,
        optionally with an identity file and a db profile argument.
        NOTE(review): values are interpolated into a shell string without
        quoting/escaping; assumes host/user/identity/db_profile come from
        trusted cluster configuration — confirm before exposing to untrusted
        input."""
        buffer = [_REMOTE_EXECUTABLE]
        if db_profile:
            buffer.append(db_profile)
        remote_command = ' '.join(buffer)
        if identity:
            ssh_command = 'ssh %s -l %s -i %s %s' % (host,
                                                     user,
                                                     identity,
                                                     remote_command)
        else:
            ssh_command = 'ssh %s -l %s %s' % (host,
                                               user,
                                               remote_command)
        return ssh_command
338