Package osh :: Package command :: Module fork
[frames] | [no frames]

Source Code for Module osh.command.fork

  1  # osh 
  2  # Copyright (C) Jack Orenstein <jao@geophile.com> 
  3  # 
  4  # This program is free software; you can redistribute it and/or modify 
  5  # it under the terms of the GNU General Public License as published by 
  6  # the Free Software Foundation; either version 2 of the License, or 
  7  # (at your option) any later version. 
  8  # 
  9  # This program is distributed in the hope that it will be useful, 
 10  # but WITHOUT ANY WARRANTY; without even the implied warranty of 
 11  # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
 12  # GNU General Public License for more details. 
 13  # 
 14  # You should have received a copy of the GNU General Public License 
 15  # along with this program; if not, write to the Free Software 
 16  # Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. 
 17   
 18  """For API usage only, (for CLI use C{osh @FORK [ ... ]} syntax instead.) 
 19  """ 
 20   
 21  import types 
 22   
 23  import osh.args 
 24  import osh.cluster 
 25  import osh.core 
 26  import osh.error 
 27  import osh.function 
 28  import osh.oshthread 
 29  import osh.spawn 
 30  import osh.util 
 31  import merge 
 32   
 33  LineOutputConsumer = osh.spawn.LineOutputConsumer 
 34  ObjectInputProvider = osh.spawn.ObjectInputProvider 
 35  ObjectOutputConsumer = osh.spawn.ObjectOutputConsumer 
 36  Spawn = osh.spawn.Spawn 
 37  Option = osh.args.Option 
 38   
 39  # CLI 
40 -def _fork():
41 return _Fork()
# API
def fork(threadgen, command, merge_key = None):
    """Creates threads and runs C{command} on each. C{threadgen} controls how many
    threads exist and how they are labelled: an integer N yields N threads labelled
    0 .. N-1; a sequence yields one thread per element, labelled with that element;
    a function is evaluated and its result (integer or sequence) handled as above;
    a cluster specification runs the command on each host of the cluster, the
    thread label being the host (type C{osh.cluster.Host}). When C{merge_key} is
    given, each thread's input must already be ordered by that key, and the
    per-thread output sequences are merged into one sequence using it.
    """
    import osh.apiparser
    # A bare Op is promoted to a one-op pipeline.
    if isinstance(command, osh.core.Op):
        command = [command]
    pipeline = osh.apiparser._sequence_op(command)
    op = _Fork()
    if not merge_key:
        return op.process_args(threadgen, pipeline)
    return op.process_args(threadgen, pipeline, merge_key)
66
class _Fork(osh.core.Generator):
    """Generator that runs a copy of its pipeline on each of several threads
    (or remote hosts), labelling and merging their outputs."""

    # state

    _threads = None            # list of osh.oshthread._OshThread, one per thread id
    _pipeline = None           # template pipeline; copied once per thread
    _merge_key = None          # optional merge key; see module fork() docs
    _function_store = None     # FunctionStore used to hide functions during pipeline copying
    _cluster_required = None   # set True via _set_cluster_required (remote usage)

    # object interface

    def __init__(self):
        # (2, 3) arity: threadgen + pipeline, with optional merge key.
        osh.core.Generator.__init__(self, '', (2, 3))
        self._function_store = FunctionStore()
        self._cluster_required = False


    # BaseOp interface

    def doc(self):
        # Module docstring doubles as the op's documentation.
        return __doc__

    def setup(self):
        args = self.args()
        threadgen = args.next()
        self._pipeline = args.next()
        self._merge_key = args.next()
        cluster, thread_ids = self.thread_ids(threadgen)
        self.setup_pipeline(cluster)
        self.setup_threads(thread_ids)
        self.setup_shared_state()

    # generator interface

    def execute(self):
        # Start every thread, then wait for each; join with a timeout so the
        # main thread stays responsive (e.g. to interrupts) while polling.
        for thread in self._threads:
            thread.pipeline.setup()
            thread.pipeline.set_receiver(self._receiver)
            thread.start()
        for thread in self._threads:
            while thread.isAlive():
                thread.join(0.1)
            thread_termination = thread.terminating_exception
            if thread_termination:
                osh.error.exception_handler(thread_termination, self, None, thread)

    # For use by this package

    def _set_cluster_required(self, required):
        # When True, thread_ids() raises usage unless threadgen names a cluster.
        self._cluster_required = required

    # For use by this class

    def thread_ids(self, threadgen, already_evaled = False):
        """Return (cluster, thread_ids) derived from threadgen. cluster is None
        unless threadgen named a cluster; thread_ids is a sequence of labels.
        Recurses at most once (already_evaled) when threadgen is a function or
        a CLI function string. Any failure is reported via self.usage()."""
        threadgen_type = type(threadgen)
        try:
            cluster = None
            thread_ids = None
            if threadgen_type in (list, tuple):
                thread_ids = threadgen
            elif isinstance(threadgen, int):
                thread_ids = range(threadgen)
            elif threadgen.isdigit():
                # Numeric string, e.g. '5' from the CLI.
                thread_ids = range(int(threadgen))
            elif threadgen_type is types.FunctionType:
                if already_evaled:
                    # A function that evaluated to another function: reject.
                    self.usage()
                else:
                    cluster, thread_ids = self.thread_ids(osh.function._Function(threadgen)(), True)
            else:
                # String, which might be a CLI function invocation
                cluster_name, pattern = (threadgen.split(':') + [None])[:2]
                cluster = osh.cluster.cluster_named(cluster_name, pattern)
                if cluster:
                    thread_ids = cluster.hosts
                else:
                    evaled_threadgen = osh.function._Function(threadgen)()
                    cluster, thread_ids = self.thread_ids(evaled_threadgen, True)
            if self._cluster_required and cluster is None:
                # API invoked remote but did not identify a cluster
                import remote
                self.usage(remote.__doc__)
            if thread_ids is None:
                self.usage()
            return cluster, thread_ids
        except:
            # Any error in interpreting threadgen is treated as bad usage.
            self.usage()

    def setup_pipeline(self, cluster):
        # For a cluster whose pipeline cannot run locally, wrap the pipeline in a
        # _Remote op that ships it to each host.
        if cluster and not self._pipeline.run_local():
            remote_op = _Remote()
            remote_op.process_args(self._pipeline)
            self._pipeline = osh.core.Pipeline()
            self._pipeline.append_op(remote_op)
        # NOTE(review): indentation of the next two appends was inferred from the
        # rendered listing; they appear to run for every fork, labelling each
        # output with its thread state and merging the streams — confirm.
        self._pipeline.append_op(_AttachThreadState())
        self._pipeline.append_op(merge.merge(self._merge_key))

    def setup_threads(self, thread_ids):
        pipeline_copier = _PipelineCopier(self)
        # Use FunctionStore to hide functions during pipeline copying
        self._function_store.hide_functions(self._pipeline)
        threads = []
        for thread_id in thread_ids:
            pipeline_copy = pipeline_copier.pipeline(thread_id)
            thread = osh.oshthread._OshThread(self, thread_id, pipeline_copy)
            threads.append(thread)
        self._function_store.restore_functions(self._pipeline)
        self._threads = threads

    def setup_shared_state(self):
        # Set up shared state for each command in the pipeline: Traverse the pipelines
        # in parallel. Use self._pipeline to allocate the shared state, and then pass the
        # state to each copy.
        pipeline_copy_iterators = [thread.pipeline.ops() for thread in self._threads]
        for pipeline_template_op in self._pipeline.ops():
            command_state = pipeline_template_op.create_command_state(self._threads)
            for pipeline_copy_iterator in pipeline_copy_iterators:
                pipeline_copy_op = pipeline_copy_iterator.next()
                pipeline_copy_op.set_command_state(command_state)
187
class _PipelineCopier(object):
    """Produces per-thread copies of a fork's pipeline. The template pipeline's
    functions are expected to have been hidden (replaced by references) before
    copying; each copy gets its functions restored after the clone."""

    _fork = None

    def __init__(self, fork):
        self._fork = fork

    def pipeline(self, thread_state):
        """Return a clone of the fork's pipeline with its functions restored."""
        fork = self._fork
        duplicate = osh.util.clone(fork._pipeline)
        fork._function_store.restore_functions(duplicate)
        return duplicate
199
class _AttachThreadState(osh.core.Op):
    """Op that prefixes every object flowing through it with the pipeline's
    thread state, so downstream consumers can tell the threads apart."""

    _thread_state = None

    def __init__(self):
        osh.core.Op.__init__(self, '', (0, 0))

    def setup(self):
        # Precompute the 1-tuple used to prefix every output object.
        self._thread_state = (self.thread_state,)

    def receive(self, object):
        prefix = self._thread_state
        # Exact list type only (not subclasses), matching original behavior.
        if type(object) is list:
            self.send(prefix + tuple(object))
        else:
            self.send(prefix + object)
214 215 # osh needs to copy pipelines to support forks. 216 # 1. Pickling: doesn't handle functions. 217 # 2. Marshaling: doesn't handle recursive structures. Pipelines are recursive due to BaseOp.parent. 218 # 3. Add a clone method to the BaseOp interface. Lots of work to handle recursion. 219 # This implementation is a combination of 1 and 3: BaseOp.replace_function_by_reference 220 # replaces functions by integer references to functions. BaseOp.restore_function does the 221 # inverse. A pipeline is copied by: 222 # - Apply replace_function_by_reference recursively to the input pipeline. 223 # - Copy the pipeline. 224 # - Apply restore_function to the copy. 225
class FunctionReference(int):
    """Integer index standing in for a function held by a FunctionStore."""
    pass

class FunctionStore(object):
    """Swaps functions out of a pipeline (replacing them with integer
    FunctionReferences) and back in, so the pipeline can be copied by
    mechanisms that cannot handle function objects."""

    _functions = None

    def __init__(self):
        self._functions = []

    # For use by this module

    def hide_functions(self, pipeline):
        pipeline.replace_function_by_reference(self)

    def restore_functions(self, pipeline):
        pipeline.restore_function(self)

    # For use by BaseOp subclasses in hiding and restoring functions

    def function_to_reference(self, x):
        """Replace a function with a FunctionReference; pass anything else through."""
        if type(x) is not types.FunctionType:
            return x
        reference = FunctionReference(len(self._functions))
        self._functions.append(x)
        return reference

    def reference_to_function(self, x):
        """Resolve a FunctionReference back to its function; pass anything else through."""
        if type(x) is not FunctionReference:
            return x
        return self._functions[x]
# Remote execution

# Name of the executable run on each remote host to execute a shipped pipeline.
_REMOTE_EXECUTABLE = 'remoteosh'
264 -def _dump(stream, object):
265 stream.dump(object)
266
def _consume_remote_stdout(consumer, threadid, object):
    """Route one object read from a remote process's stdout: pickled remote
    exceptions are rebuilt and sent to the osh exception handler; everything
    else is forwarded to C{consumer}."""
    if not isinstance(object, osh.error.PickleableException):
        consumer.send(object)
        return
    exception = object.recreate_exception()
    osh.error.exception_handler(exception,
                                object.command_description(),
                                object.input(),
                                threadid)
273
def _consume_remote_stderr(consumer, threadid, line):
    """Forward one remote stderr line to the osh stderr handler, suppressing a
    known spurious message."""
    # UGLY HACK (inherited): remoteosh can occasionally emit
    # "[Errno 9] Bad file descriptor" on stderr — apparently io against a
    # process stream whose process has already completed. The root cause was
    # never tracked down, so that one line is simply dropped.
    if '[Errno 9] Bad file descriptor' in line:
        return
    osh.error.stderr_handler(line, consumer, None, threadid)
281
class _Remote(osh.core.Generator):
    """Generator that ships its pipeline to one remote host (the thread state)
    over ssh and relays the remote process's stdout/stderr back into osh."""

    # state

    _pipeline = None   # pipeline to execute remotely

    # object interface

    def __init__(self):
        # Exactly one arg: the pipeline to run remotely.
        osh.core.Generator.__init__(self, '', (1, 1))

    # BaseOp interface

    def doc(self):
        # Module docstring doubles as the op's documentation.
        return __doc__

    def setup(self):
        self._pipeline = self.args().next()

    # generator interface

    def execute(self):
        # thread_state is the target host here; presumably an
        # osh.cluster.Host with address/user/db_profile attributes — see fork().
        host = self.thread_state
        process = Spawn(
            self._remote_command(host.address, host.user, host.db_profile),
            # Send verbosity, the pipeline, and the thread state to remoteosh's stdin.
            ObjectInputProvider(lambda stream, object: _dump(stream, object),
                                [osh.core.verbosity, self._pipeline, self.thread_state]),
            ObjectOutputConsumer(lambda object: _consume_remote_stdout(self, host, object)),
            LineOutputConsumer(lambda line: _consume_remote_stderr(self, host, line)))
        process.run()
        # Re-raise any exception that terminated the spawned process.
        if process.terminating_exception():
            e = process.terminating_exception()
            raise e

    # for use by this class

    def _remote_command(self, host, user, db_profile):
        """Build the ssh command line that runs remoteosh (plus optional
        db profile argument) on C{host} as C{user}."""
        remote_command = [_REMOTE_EXECUTABLE]
        if db_profile:
            remote_command.append(db_profile)
        ssh_command = 'ssh %s -l %s %s' % (host, user, ' '.join(remote_command))
        return ssh_command
324