class Pipeline(BasePipeline):
"""Local pipeline implementation using `multiprocessing`."""
def run(
self,
parameters: Optional[Dict[str, Dict[str, Any]]] = None,
use_cache: bool = True,
) -> "Distiset":
"""Runs the pipeline.
Args:
parameters: A dictionary with the step name as the key and a dictionary with
the runtime parameters for the step as the value. Defaults to `None`.
use_cache: Whether to use the cache from previous pipeline runs. Defaults to
`True`.
Returns:
The `Distiset` created by the pipeline.
Raises:
RuntimeError: If the pipeline fails to load all the steps.
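        Examples:
            A minimal usage sketch showing how `run` is typically called. The step
            classes below are illustrative placeholders, not part of this module:

            ```python
            with Pipeline(name="my-pipeline") as pipeline:
                loader = SomeGeneratorStep()  # hypothetical generator step
                task = SomeTaskStep()  # hypothetical task step
                loader >> task

            distiset = pipeline.run(
                parameters={task.name: {"some_runtime_param": 42}},
                use_cache=True,
            )
            ```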
"""
log_queue = mp.Queue()
        # The runtime parameters must be set (via `super().run`) before `setup_logging`
        # is called, so the logging configuration stays consistent with them.
super().run(parameters, use_cache)
setup_logging(log_queue, filename=str(self._cache_location["log_file"])) # type: ignore
self._logger = logging.getLogger("distilabel.pipeline.local")
if self._dry_run:
            # This message is logged here to ensure the freshly configured logger is used.
self._logger.info("🌵 Dry run mode")
if self._batch_manager is None:
self._batch_manager = _BatchManager.from_dag(self.dag)
# If the batch manager is not able to generate batches, that means that the loaded
# `_BatchManager` from cache didn't have any remaining batches to process i.e.
# the previous pipeline execution was completed successfully.
if not self._batch_manager.can_generate():
self._logger.info(
"💾 Loaded batch manager from cache doesn't have any remaining data. Returning"
" `Distiset` from cache data..."
)
stop_logging()
return create_distiset(
self._cache_location["data"],
pipeline_path=self._cache_location["pipeline"],
log_filename_path=self._cache_location["log_file"],
enable_metadata=self._enable_metadata,
)
buffer_data_path = self._cache_location["data"]
self._logger.info(f"📝 Pipeline data will be written to '{buffer_data_path}'")
write_buffer = _WriteBuffer(buffer_data_path, self.dag.leaf_steps)
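        # One worker process is spawned per step in the DAG, so every step can run its
        # processing loop concurrently.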
num_processes = len(self.dag)
ctx = mp.get_context() # type: ignore
with ctx.Manager() as manager, ctx.Pool(
num_processes,
initializer=_init_worker,
initargs=(log_queue,),
) as pool:
self.output_queue: "Queue[Any]" = manager.Queue()
self.shared_info = self._create_shared_info_dict(manager)
self._handle_keyboard_interrupt(manager=manager, pool=pool)
# Run the steps using the pool of processes
self._run_steps_in_loop(pool, manager, self.output_queue, self.shared_info)
# Wait for all the steps to be loaded correctly
if not self._all_steps_loaded():
write_buffer.close()
self._batch_manager = None
stop_logging()
raise RuntimeError(
"Failed to load all the steps. Could not run pipeline."
) from _SUBPROCESS_EXCEPTION
# Send the "first" batches to the steps so the batches starts flowing through
# the input queues and output queue
self._request_initial_batches()
# Start a loop to receive the output batches from the steps
self._run_output_queue_loop_in_thread(write_buffer)
            # Send `None` to the steps' input queues in case any step is still waiting
self._notify_steps_to_stop()
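            # Wait for every worker process to finish before reading the final data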
pool.close()
pool.join()
write_buffer.close()
distiset = create_distiset(
self._cache_location["data"],
pipeline_path=self._cache_location["pipeline"],
log_filename_path=self._cache_location["log_file"],
enable_metadata=self._enable_metadata,
)
stop_logging()
return distiset
def _run_output_queue_loop_in_thread(self, write_buffer: "_WriteBuffer") -> None:
"""Runs the output queue loop in a separate thread to receive the output batches
        from the steps. This is done to prevent the signal handler from blocking the
        loop, which would prevent the pipeline from stopping correctly.
Args:
write_buffer: The write buffer to write the data from the leaf steps to disk.
"""
thread = threading.Thread(target=self._output_queue_loop, args=(write_buffer,))
thread.start()
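        # Block until the output queue loop finishes, either because the batch manager
        # has no more batches to generate or because a stop was requested.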
thread.join()
def _notify_steps_to_stop(self) -> None:
"""Notifies the steps to stop their infinite running loop by sending `None` to
their input queues."""
for step_name in self.dag:
if input_queue := self.dag.get_step(step_name).get(INPUT_QUEUE_ATTR_NAME):
input_queue.put(None)
def _output_queue_loop(self, write_buffer: "_WriteBuffer") -> None:
"""Loop to receive the output batches from the steps and manage the flow of the
batches through the pipeline.
Args:
write_buffer: The write buffer to write the data from the leaf steps to disk.
"""
while self._batch_manager.can_generate() and not _STOP_CALLED: # type: ignore
self._logger.debug("Waiting for output batch from step...")
if (batch := self.output_queue.get()) is None:
self._logger.debug("Received `None` from output queue. Breaking loop.")
break
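            # Batches from leaf steps contain the final data, so they are written to disk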
if batch.step_name in self.dag.leaf_steps:
write_buffer.add_batch(batch)
# If `_STOP_CALLED` was set to `True` while waiting for the output queue, then
# we need to handle the stop of the pipeline and break the loop to avoid
# propagating the batches through the pipeline and making the stop process
# slower.
if _STOP_CALLED:
self._handle_batch_on_stop(batch)
break
self._logger.debug(
f"Received batch with seq_no {batch.seq_no} from step '{batch.step_name}'"
f" from output queue: {batch}"
)
self._manage_batch_flow(batch)
if _STOP_CALLED:
self._handle_stop(write_buffer)
def _manage_batch_flow(self, batch: "_Batch") -> None:
"""Checks if the step that generated the batch has more data in its buffer to
generate a new batch. If there's data, then a new batch is sent to the step. If
        the step has no data in its buffer, then its predecessor generator steps are
requested to send a new batch.
Args:
batch: The batch that was processed.
"""
assert self._batch_manager, "Batch manager is not set"
# Make sure to send the `LAST_BATCH_SENT_FLAG` to the predecessors of the convergence
# step if the batch is the last one, so they stop their processing loop even if
# they haven't received the last batch because of the routing function.
if self._is_convergence_step(batch.step_name) and batch.last_batch:
for step_name in self.dag.get_step_predecessors(batch.step_name):
self._send_last_batch_flag_to_step(step_name)
route_to, routed = self._get_successors(batch)
# Keep track of the steps that the batch was routed to
if routed:
batch.batch_routed_to = route_to
self._register_batch(batch)
step = self._get_step_from_batch(batch)
# Add the batch to the successors input buffers
for successor in route_to:
# Copy batch to avoid modifying the same reference in the batch manager
batch_to_add = batch.copy() if len(route_to) > 1 else batch
self._batch_manager.add_batch(successor, batch_to_add)
# Check if the step is a generator and if there are successors that need data
# from this step. This usually happens when the generator `batch_size` is smaller
# than the `input_batch_size` of the successor steps.
if (
step.is_generator
and step.name in self._batch_manager.step_empty_buffers(successor)
):
last_batch_sent = self._batch_manager.get_last_batch_sent(step.name)
self._send_batch_to_step(last_batch_sent.next_batch()) # type: ignore
# If successor step has enough data in its buffer to create a new batch, then
# send the batch to the step.
if new_batch := self._batch_manager.get_batch(successor):
self._send_batch_to_step(new_batch)
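        # A generator step has no input buffers of its own, so there is nothing left to
        # do for it; new batches are requested from it by its successors when needed.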
if step.is_generator:
return
# Step ("this", the one from which the batch was received) has enough data on its
# buffers to create a new batch
if new_batch := self._batch_manager.get_batch(step.name): # type: ignore
self._send_batch_to_step(new_batch)
else:
self._request_more_batches_if_needed(step)
self._cache()
def _register_batch(self, batch: "_Batch") -> None:
"""Registers a batch in the batch manager.
Args:
batch: The batch to register.
"""
self._batch_manager.register_batch(batch) # type: ignore
self._logger.debug(
f"Batch {batch.seq_no} from step '{batch.step_name}' registered in batch"
" manager"
)
def _get_successors(self, batch: "_Batch") -> Tuple[List[str], bool]:
"""Gets the successors and the successors to which the batch has to be routed.
Args:
batch: The batch to which the successors will be determined.
Returns:
The successors to route the batch to and whether the batch was routed using
a routing function.
"""
node = self.dag.get_step(batch.step_name)
step: "Step" = node[STEP_ATTR_NAME]
successors = list(self.dag.get_step_successors(step.name)) # type: ignore
route_to = successors
# Check if the step has a routing function to send the batch to specific steps
if routing_batch_function := node.get(ROUTING_BATCH_FUNCTION_ATTR_NAME):
route_to = routing_batch_function(batch, successors)
successors_str = ", ".join(f"'{successor}'" for successor in route_to)
self._logger.info(
f"🚏 Using '{step.name}' routing function to send batch {batch.seq_no} to steps: {successors_str}"
)
return route_to, route_to != successors
def _get_step_from_batch(self, batch: "_Batch") -> "Step":
"""Gets the `Step` instance from a batch.
Args:
batch: The batch to get the step from.
Returns:
The `Step` instance.
"""
return self.dag.get_step(batch.step_name)[STEP_ATTR_NAME]
def _request_more_batches_if_needed(self, step: "Step") -> None:
"""Request more batches to the predecessors steps of `step` if needed.
Args:
step: The step of which it has to be checked if more batches are needed from
its predecessors.
"""
empty_buffers = self._batch_manager.step_empty_buffers(step.name) # type: ignore
for previous_step_name in empty_buffers:
if previous_step_name not in self.dag.root_steps:
continue
last_batch = self._batch_manager.get_last_batch_sent(previous_step_name) # type: ignore
if last_batch is None:
continue
self._logger.debug(
f"Step '{step.name}' input buffer for step '{previous_step_name}' is"
" empty. Requesting new batch..."
)
self._send_batch_to_step(last_batch.next_batch())
def _handle_stop(self, write_buffer: "_WriteBuffer") -> None:
"""Handles the stop of the pipeline execution, which will stop the steps from
processing more batches and wait for the output queue to be empty, to not lose
any data that was already processed by the steps before the stop was called.
Args:
write_buffer: The write buffer to write the data from the leaf steps to disk.
"""
self._logger.debug("Handling stop of the pipeline execution...")
# Add the remaining batches in the input queues back to the batch manager
for step_name in self.dag:
node = self.dag.get_step(step_name)
step: "_Step" = node[STEP_ATTR_NAME]
if step.is_generator:
continue
if input_queue := node.get(INPUT_QUEUE_ATTR_NAME):
while not input_queue.empty():
batch = input_queue.get()
if batch is None:
continue
self._batch_manager.add_batch( # type: ignore
to_step=step_name, batch=batch, prepend=True
)
self._logger.debug(
f"Adding batch back to the batch manager: {batch}"
)
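                # Send `None` so the step exits its loop once it finishes the batch it
                # is currently processing.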
input_queue.put(None)
        # Wait for each step's input queue to be empty, which means that all the steps
        # have finished processing the batches sent before the stop flag was set.
for step_name in self.dag:
self._wait_step_input_queue_empty(step_name)
# Consume the output queue until it's empty to not lose any data that was already
# processed by the steps before stop was called.
while not self.output_queue.empty():
batch = self.output_queue.get()
if batch is None:
continue
if batch.step_name in self.dag.leaf_steps:
write_buffer.add_batch(batch)
self._handle_batch_on_stop(batch)
self._cache()
def _handle_batch_on_stop(self, batch: "_Batch") -> None:
"""Handles a batch that was received from the output queue when the pipeline was
stopped. It will add and register the batch in the batch manager.
Args:
batch: The batch to handle.
"""
self._batch_manager.register_batch(batch) # type: ignore
step: "Step" = self.dag.get_step(batch.step_name)[STEP_ATTR_NAME]
for successor in self.dag.get_step_successors(step.name): # type: ignore
self._batch_manager.add_batch(successor, batch) # type: ignore
def _wait_step_input_queue_empty(self, step_name: str) -> Union["Queue[Any]", None]:
"""Waits for the input queue of a step to be empty.
Args:
step_name: The name of the step.
Returns:
            The input queue of the step if it is loaded and hasn't finished, `None` otherwise.
"""
if self._check_step_not_loaded_or_finished(step_name):
return None
if input_queue := self.dag.get_step(step_name).get(INPUT_QUEUE_ATTR_NAME):
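            # Busy-wait until the step has consumed everything that was queued before
            # the stop was requested.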
while input_queue.qsize() != 0:
pass
return input_queue
def _create_shared_info_dict(self, manager: "SyncManager") -> "DictProxy[str, Any]":
"""Creates the shared information dictionary to be used by the processes.
Args:
manager: The manager to create the shared information.
Returns:
The shared information dictionary.
"""
        # TODO: not very important, but we could use a different lock for each concern
return manager.dict(
**{
_STEPS_LOADED_KEY: manager.list(),
_STEPS_LOADED_LOCK_KEY: manager.Lock(),
_CUDA_LLM_DEVICE_PLACEMENT_KEY: manager.dict(**{}),
_CUDA_LLM_DEVICE_PLACEMENT_LOCK_KEY: manager.Lock(),
}
)
def _all_steps_loaded(self) -> bool:
"""Waits for all the steps to load.
Returns:
`True` if all the steps have been loaded correctly, `False` otherwise.
"""
def _update_all_steps_loaded(steps_loaded: List[str]) -> None:
with _STEPS_LOADED_LOCK:
_STEPS_LOADED.update(steps_loaded)
self._logger.info("⏳ Waiting for all the steps to load...")
previous_message = None
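        # Poll the shared dict that the worker processes update while loading, until
        # every step reports it has loaded, an error code is reported, or stop is called.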
while not _STOP_CALLED:
with self.shared_info[_STEPS_LOADED_LOCK_KEY]:
steps_loaded = self.shared_info[_STEPS_LOADED_KEY]
num_steps_loaded = (
len(steps_loaded)
if steps_loaded != [_STEPS_LOADED_ERROR_CODE]
else 0
)
self._logger.debug(f"Steps loaded: {steps_loaded}")
message = f"⏳ Steps loaded: {num_steps_loaded}/{len(self.dag)}"
if num_steps_loaded > 0 and message != previous_message:
self._logger.info(message)
previous_message = message
if num_steps_loaded == len(self.dag):
self._logger.info("✅ All the steps have been loaded!")
_update_all_steps_loaded(steps_loaded)
return True
if steps_loaded == [_STEPS_LOADED_ERROR_CODE]:
self._logger.error("❌ Failed to load all the steps")
_update_all_steps_loaded(steps_loaded)
return False
time.sleep(2.5)
return not _STOP_CALLED
def _request_initial_batches(self) -> None:
"""Requests the initial batches to the generator steps."""
assert self._batch_manager, "Batch manager is not set"
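        # First, send any batches that can already be built from data buffered in the
        # batch manager (e.g. restored from the cache of a previous, interrupted run).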
for step in self._batch_manager._steps.values():
if batch := step.get_batch():
self._logger.debug(
f"Sending initial batch to '{step.step_name}' step: {batch}"
)
self._send_batch_to_step(batch)
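        # Then request a batch from every generator (root) step, resuming from the next
        # sequence number if a previous batch was cached.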
for step_name in self.dag.root_steps:
seq_no = 0
if last_batch := self._batch_manager.get_last_batch(step_name):
seq_no = last_batch.seq_no + 1
batch = _Batch(seq_no=seq_no, step_name=step_name, last_batch=self._dry_run)
self._logger.debug(
f"Requesting initial batch to '{step_name}' generator step: {batch}"
)
self._send_batch_to_step(batch)
def _send_batch_to_step(self, batch: "_Batch") -> None:
"""Sends a batch to the input queue of a step.
Args:
batch: The batch to send.
"""
self._logger.debug(
f"Setting batch {batch.seq_no} as last batch sent to '{batch.step_name}': {batch}"
)
self._batch_manager.set_last_batch_sent(batch) # type: ignore
self._logger.debug(
f"Sending batch {batch.seq_no} to step '{batch.step_name}': {batch}"
)
input_queue = self.dag.get_step(batch.step_name)[INPUT_QUEUE_ATTR_NAME]
input_queue.put(batch)
    def _is_convergence_step(self, step_name: str) -> bool:
        """Checks if a step is a convergence step.
        Args:
            step_name: The name of the step.
        Returns:
            `True` if the step is a convergence step, `False` otherwise.
        """
        return bool(self.dag.get_step(step_name).get(CONVERGENCE_STEP_ATTR_NAME))
def _send_last_batch_flag_to_step(self, step_name: str) -> None:
"""Sends the `LAST_BATCH_SENT_FLAG` to a step to stop processing batches.
Args:
step_name: The name of the step.
"""
batch = self._batch_manager.get_last_batch_sent(step_name) # type: ignore
if batch and batch.last_batch:
return
self._logger.debug(
f"Sending `LAST_BATCH_SENT_FLAG` to '{step_name}' step to stop processing"
" batches..."
)
input_queue = self.dag.get_step(step_name)[INPUT_QUEUE_ATTR_NAME]
input_queue.put(LAST_BATCH_SENT_FLAG)
self._batch_manager.set_last_batch_flag_sent_to(step_name) # type: ignore
def _run_steps_in_loop(
self,
pool: "Pool",
manager: "SyncManager",
output_queue: "Queue[_Batch]",
shared_info: "DictProxy[str, Any]",
) -> None:
"""Using the `pool`, runs the steps in the DAG in an infinite loop waiting for
input batches and sending the output batches to the `output_queue`.
Each `Step` is wrapped in a `_ProcessWrapper`, which will handle the lifecycle of
the `Step` and the communication with the `input_queue` and `output_queue`. The
`_ProcessWrapper.run` method is the target function of the process.
Args:
pool: The pool of processes.
manager: The manager to create the queues.
output_queue: The queue to send the output batches.
shared_info: The shared information between the processes.
"""
for step_name in self.dag:
step: "Step" = self.dag.get_step(step_name)[STEP_ATTR_NAME]
input_queue = manager.Queue()
self.dag.set_step_attr(step.name, INPUT_QUEUE_ATTR_NAME, input_queue)
# Set `pipeline` to `None` as in some Python environments the pipeline is not
# picklable and it will raise an error when trying to send the step to the process.
# `TypeError: cannot pickle 'code' object`
step.pipeline = None
process_wrapper = _ProcessWrapper(
step=step,
input_queue=input_queue,
output_queue=output_queue,
shared_info=shared_info,
dry_run=self._dry_run,
)
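            # `_ProcessWrapper.run` blocks until the step finishes, so each `apply_async`
            # call occupies one pool worker for the whole pipeline execution.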
pool.apply_async(
process_wrapper.run,
callback=self._finished_callback,
error_callback=self._error_callback,
) # type: ignore
def _error_callback(self, e: BaseException) -> None:
"""Error callback that will be called when an error occurs in a `Step` process.
Args:
e: The exception raised by the process.
"""
global _SUBPROCESS_EXCEPTION
        # First check that the exception is a `_ProcessWrapperException`; if it is not,
        # log it and stop the pipeline, since the error is unhandled
if not isinstance(e, _ProcessWrapperException):
self._logger.error(f"❌ Failed with an unhandled exception: {e}")
self._stop()
return
if e.is_load_error:
self._logger.error(f"❌ Failed to load step '{e.step.name}': {e.message}")
with self.shared_info[_STEPS_LOADED_LOCK_KEY]:
self.shared_info[_STEPS_LOADED_KEY] = [_STEPS_LOADED_ERROR_CODE]
_SUBPROCESS_EXCEPTION = e.subprocess_exception
_SUBPROCESS_EXCEPTION.__traceback__ = tblib.Traceback.from_string(
e.formatted_traceback
).as_traceback()
return
# If the step is global, is not in the last trophic level and has no successors,
# then we can ignore the error and continue executing the pipeline
if (
e.step.is_global
and not self.dag.step_in_last_trophic_level(e.step.name)
and list(self.dag.get_step_successors(e.step.name)) == []
):
self._logger.error(
f"✋ An error occurred when running global step '{e.step.name}' with no"
" successors and not in the last trophic level. Pipeline execution can"
f" continue. Error will be ignored."
)
self._logger.error(f"Subprocess traceback:\n\n{e.formatted_traceback}")
return
# Global step with successors failed
self._logger.error(f"An error occurred in global step '{e.step.name}'")
self._logger.error(f"Subprocess traceback:\n\n{e.formatted_traceback}")
self._cache()
self._stop()
def _finished_callback(self, step_name: str) -> None:
"""Callback that will be called when a `Step` process finishes.
Args:
step_name: The name of the step that finished.
"""
with _STEPS_FINISHED_LOCK:
_STEPS_FINISHED.add(step_name)
def _check_step_not_loaded_or_finished(self, step_name: str) -> bool:
"""Checks if a step is not loaded or already finished.
Args:
step_name: The name of the step.
Returns:
`True` if the step is not loaded or already finished, `False` otherwise.
"""
with _STEPS_LOADED_LOCK:
if step_name not in _STEPS_LOADED:
return True
with _STEPS_FINISHED_LOCK:
if step_name in _STEPS_FINISHED:
return True
return False
def _stop(
self, manager: Optional["SyncManager"] = None, pool: Optional["Pool"] = None
) -> None:
"""Stops the pipeline execution. It will first send `None` to the input queues
of all the steps and then wait until the output queue is empty i.e. all the steps
finished processing the batches that were sent before the stop flag. Then it will
send `None` to the output queue to notify the pipeline to stop."""
global _STOP_CALLED
with _STOP_CALLED_LOCK:
if _STOP_CALLED:
global _STOP_CALLS
_STOP_CALLS += 1
if _STOP_CALLS == 1:
self._logger.warning(
"🛑 Press again to force the pipeline to stop."
)
elif _STOP_CALLS > 1:
self._logger.warning("🛑 Forcing pipeline interruption.")
import gc
import sys
if manager:
manager.shutdown()
if pool:
pool.close()
pool.terminate()
gc.collect()
sys.exit(1)
return
_STOP_CALLED = True
self._logger.debug(f"Steps loaded before calling `stop`: {_STEPS_LOADED}")
self._logger.info(
"🛑 Stopping pipeline. Waiting for steps to finish processing batches..."
)
self._logger.debug("Sending `None` to the output queue to notify stop...")
self.output_queue.put(None)
def _handle_keyboard_interrupt(
self, manager: Optional["SyncManager"] = None, pool: Optional["Pool"] = None
) -> None:
"""Handles KeyboardInterrupt signal sent during the Pipeline.run method.
        It will try to call `self._stop` (if the pipeline hasn't started yet, it won't
        have any effect), and if the pool has already been started, it will close it
        before exiting the program.
"""
def signal_handler(signumber: int, frame: Any) -> None:
self._stop(manager=manager, pool=pool)
signal.signal(signal.SIGINT, signal_handler)