# Source code for ccflow_celery.evaluators

import graphlib
from logging import getLogger
from typing import Any, Optional

from ccflow import EvaluatorBase, GenericResult
from ccflow.callable import ModelEvaluationContext
from ccflow.evaluators.common import get_dependency_graph
from celery import group
from pydantic import Field, PrivateAttr
from typing_extensions import override

from .app import CeleryApp

__all__ = (
    "CeleryEvaluator",
    "CeleryGraphEvaluator",
)

_log = getLogger(__name__)


def _resolve_origin(cls):
    """Resolve parameterized generics (e.g. GenericContext[int]) to their origin class."""
    meta = getattr(cls, "__pydantic_generic_metadata__", None)
    if meta and meta.get("origin") is not None:
        return meta["origin"]
    return cls


def _model_fqn(model) -> str:
    """Get the fully qualified class name of a model."""
    # Resolve parameterized generics to their origin class before naming.
    origin = _resolve_origin(model.__class__)
    return f"{origin.__module__}.{origin.__qualname__}"


def _context_fqn(context) -> str:
    """Get the fully qualified class name of a context.

    Behaviorally identical to ``_model_fqn`` — the logic was duplicated, so this
    now delegates to it; the separate name is kept so call sites read clearly.
    """
    return _model_fqn(context)


def _serialize_model(model):
    """Serialize a model to (fqn, config_dict) for Celery task dispatch."""
    # "type_" is excluded from the dump; the fully qualified name carries the type.
    if hasattr(model, "model_dump"):
        config = model.model_dump(exclude={"type_"})
    else:
        config = {}
    return _model_fqn(model), config


def _serialize_context(ctx):
    """Serialize a context to (fqn, config_dict) for Celery task dispatch."""
    # "type_" is excluded from the dump; the fully qualified name carries the type.
    if hasattr(ctx, "model_dump"):
        config = ctx.model_dump(exclude={"type_"})
    else:
        config = {}
    return _context_fqn(ctx), config


def _unwrap_eval_ctx(eval_ctx):
    """Unwrap evaluator-wrapping layers to get the actual model and context.

    The dependency graph may wrap nodes as ModelEvaluationContext(model=evaluator,
    context=ModelEvaluationContext(model=actual_model, context=actual_context)).
    """
    current = eval_ctx
    # Walk down through nested ModelEvaluationContext layers until the
    # context attribute is no longer itself an evaluation context.
    while True:
        nested = current.context
        if not isinstance(nested, ModelEvaluationContext):
            break
        current = nested
    return current.model, current.context


class CeleryEvaluator(EvaluatorBase):
    """Evaluator that dispatches model execution to Celery workers.

    Serializes the model and context, submits as a Celery task, and waits
    for the result. The model must be reconstructable from its Pydantic
    config dump.
    """

    # Configuration wrapper from which the Celery client app is built lazily.
    app: CeleryApp = Field(default_factory=CeleryApp)
    timeout: Optional[float] = Field(default=300.0, description="Timeout in seconds for waiting on task result")
    task_name: str = Field(default="ccflow_celery.tasks.execute_model_task", description="Registered Celery task name")
    # Cached Celery app instance; populated on first use by _get_celery_app.
    _celery_app: Any = PrivateAttr(default=None)

    def _get_celery_app(self):
        """Create (on first call) and return the underlying Celery application."""
        if self._celery_app is None:
            self._celery_app = self.app.get_app()
        return self._celery_app

    @override
    def __call__(self, context: ModelEvaluationContext) -> Any:
        model = context.model
        ctx = context.context

        # For __deps__ calls or non-__call__ calls, execute locally
        if hasattr(context, "fn") and context.fn != "__call__":
            return context()

        # Serialize model and context
        model_class, model_config = _serialize_model(model)
        context_class, context_config = _serialize_context(ctx)

        celery_app = self._get_celery_app()
        _log.info("Submitting Celery task: %s(%s)", model_class, context_class)

        # Send task to Celery
        task = celery_app.send_task(
            self.task_name,
            args=[model_class, model_config, context_class, context_config],
        )

        # Wait for result
        result_data = task.get(timeout=self.timeout)
        _log.info("Celery task completed: %s", task.id)

        # Reconstruct result: prefer the model's declared result type, falling
        # back to a GenericResult wrapper if typed reconstruction fails.
        if hasattr(model, "result_type"):
            result_cls = model.result_type
            try:
                return result_cls(**result_data)
            except Exception:
                # Previously swallowed silently; log the cause so a payload/type
                # mismatch is diagnosable while still returning best-effort.
                _log.warning(
                    "Failed to reconstruct %s from task result; falling back to GenericResult",
                    result_cls,
                    exc_info=True,
                )
        return GenericResult(value=result_data)
class CeleryGraphEvaluator(EvaluatorBase):
    """Evaluator that parallelizes DAG execution via Celery.

    Builds the dependency graph (like GraphEvaluator), then submits independent
    nodes as parallel Celery tasks using Celery groups. Nodes that depend on
    other nodes wait for their dependencies first.
    """

    # Configuration wrapper from which the Celery client app is built lazily.
    app: CeleryApp = Field(default_factory=CeleryApp)
    timeout: Optional[float] = Field(default=600.0, description="Timeout for the entire graph execution")
    task_name: str = Field(default="ccflow_celery.tasks.execute_model_task", description="Registered Celery task name")
    # Cached Celery app instance; populated on first use by _get_celery_app.
    _celery_app: Any = PrivateAttr(default=None)
    # Re-entrancy guard: nested evaluations execute locally instead of re-dispatching.
    _is_evaluating: bool = PrivateAttr(False)

    def _get_celery_app(self):
        """Create (on first call) and return the underlying Celery application."""
        if self._celery_app is None:
            self._celery_app = self.app.get_app()
        return self._celery_app

    def _make_signature(self, celery_app, eval_ctx):
        """Build the Celery task signature for a single graph node."""
        model, ctx = _unwrap_eval_ctx(eval_ctx)
        model_class, model_config = _serialize_model(model)
        context_class, context_config = _serialize_context(ctx)
        return celery_app.signature(
            self.task_name,
            args=[model_class, model_config, context_class, context_config],
        )

    @override
    def __call__(self, context: ModelEvaluationContext) -> Any:
        # Avoid re-entrancy: a nested call evaluates locally.
        if self._is_evaluating:
            return context()

        self._is_evaluating = True
        root_result = None
        try:
            graph = get_dependency_graph(context)
            ts = graphlib.TopologicalSorter(graph.graph)
            celery_app = self._get_celery_app()
            results_map = {}

            # Process nodes level by level for maximum parallelism
            ts.prepare()
            while ts.is_active():
                ready_nodes = list(ts.get_ready())

                # Build one task signature per node that has no pending dependencies.
                tasks = [(key, self._make_signature(celery_app, graph.ids[key])) for key in ready_nodes]

                if len(tasks) == 1:
                    # Single task — submit directly (no group overhead)
                    key, sig = tasks[0]
                    async_result = sig.apply_async()
                    result = async_result.get(timeout=self.timeout)
                    results_map[key] = result
                    if key == graph.root_id:
                        root_result = result
                    ts.done(key)
                else:
                    # Multiple tasks — use a Celery group so the level runs in parallel
                    keys = [k for k, _ in tasks]
                    sigs = [s for _, s in tasks]
                    group_result = group(sigs).apply_async()
                    group_results = group_result.get(timeout=self.timeout)
                    for key, result in zip(keys, group_results):
                        results_map[key] = result
                        if key == graph.root_id:
                            root_result = result
                        ts.done(key)
        finally:
            self._is_evaluating = False

        # Reconstruct the root node's result: prefer the model's declared result
        # type, falling back to a GenericResult wrapper if reconstruction fails.
        if root_result is not None and hasattr(context, "model"):
            model = context.model
            if hasattr(model, "result_type"):
                result_cls = model.result_type
                try:
                    return result_cls(**root_result)
                except Exception:
                    # Previously swallowed silently; log the cause so a payload/type
                    # mismatch is diagnosable while still returning best-effort.
                    _log.warning(
                        "Failed to reconstruct %s from root task result; falling back to GenericResult",
                        result_cls,
                        exc_info=True,
                    )
        return GenericResult(value=root_result)