diff --git a/README.md b/README.md index b39a6ba..48e9f20 100644 --- a/README.md +++ b/README.md @@ -120,19 +120,29 @@ parallel = MARS::Workflows::Parallel.new( ### Gates -Create conditional branching in your workflows: +Gates act as guards that either let the workflow continue or divert to a fallback path: ```ruby gate = MARS::Gate.new( - "Decision Gate", - condition: ->(input) { input[:score] > 0.5 ? :success : :failure }, - branches: { - success: success_workflow, + "Validation Gate", + check: ->(input) { :failure unless input[:score] > 0.5 }, + fallbacks: { failure: failure_workflow } ) ``` +Control halt scope — `:local` (default) stops only the parent workflow, `:global` propagates to the root: + +```ruby +gate = MARS::Gate.new( + "Critical Gate", + check: ->(input) { :error unless input[:valid] }, + fallbacks: { error: error_workflow }, + halt_scope: :global +) +``` + ### Visualization Generate Mermaid diagrams to visualize your workflows: diff --git a/examples/complex_llm_workflow/generator.rb b/examples/complex_llm_workflow/generator.rb index 4ea2c77..26d2e62 100755 --- a/examples/complex_llm_workflow/generator.rb +++ b/examples/complex_llm_workflow/generator.rb @@ -91,8 +91,8 @@ def tools ) gate = MARS::Gate.new( - condition: ->(input) { input.split.length < 10 ? :success : :failure }, - branches: { + check: ->(input) { :failure unless input.split.length < 10 }, + fallbacks: { failure: error_workflow } ) diff --git a/examples/complex_workflow/generator.rb b/examples/complex_workflow/generator.rb index 62adc6b..dc54fc6 100755 --- a/examples/complex_workflow/generator.rb +++ b/examples/complex_workflow/generator.rb @@ -46,8 +46,8 @@ class Agent5 < MARS::Agent # Create the gate that decides between exit or continue gate = MARS::Gate.new( - condition: ->(input) { input[:result] }, - branches: { + check: ->(input) { input[:result] }, + fallbacks: { warning: sequential_workflow, error: parallel_workflow } diff --git a/examples/simple_workflow/generator.rb b/examples/simple_workflow/generator.rb index 4dd01c4..2f8242c 100755 --- a/examples/simple_workflow/generator.rb +++ b/examples/simple_workflow/generator.rb @@ -26,8 +26,8 @@ class Agent3 < MARS::Agent # Create the gate that decides between exit or continue gate = MARS::Gate.new( - condition: ->(input) { input[:result] }, - branches: { + check: ->(input) { input[:result] }, + fallbacks: { success: success_workflow } ) diff --git a/lib/mars/gate.rb b/lib/mars/gate.rb index 8a2bafd..5ce7502 100644 --- a/lib/mars/gate.rb +++ b/lib/mars/gate.rb @@ -2,21 +2,51 @@ module MARS class Gate < Runnable - def initialize(name = "Gate", condition:, branches:, **kwargs) + class << self + def check(&block) + @check_block = block + end + + attr_reader :check_block + + def fallback(key, runnable) + fallbacks_map[key] = runnable + end + + def fallbacks_map + @fallbacks_map ||= {} + end + + def halt_scope(scope = nil) + scope ? @halt_scope = scope : @halt_scope + end + end + + def initialize(name = "Gate", check: nil, fallbacks: nil, halt_scope: nil, **kwargs) super(name: name, **kwargs) - @condition = condition - @branches = branches + @check = check || self.class.check_block + @fallbacks = fallbacks || self.class.fallbacks_map + @halt_scope = halt_scope || self.class.halt_scope || :local end def run(input) - result = condition.call(input) + result = check.call(input) + + return input unless result + + branch = fallbacks[result] + raise ArgumentError, "No fallback registered for #{result.inspect}" unless branch - branches[result] || input + Halt.new(resolve_branch(branch).run(input), scope: @halt_scope) end private - attr_reader :condition, :branches + attr_reader :check, :fallbacks + + def resolve_branch(branch) + branch.is_a?(Class) ? branch.new : branch + end end end diff --git a/lib/mars/halt.rb b/lib/mars/halt.rb new file mode 100644 index 0000000..043e80e --- /dev/null +++ b/lib/mars/halt.rb @@ -0,0 +1,15 @@ +# frozen_string_literal: true + +module MARS + class Halt + attr_reader :result, :scope + + def initialize(result, scope: :local) + @result = result + @scope = scope + end + + def local? = scope == :local + def global? = scope == :global + end +end diff --git a/lib/mars/rendering/graph/gate.rb b/lib/mars/rendering/graph/gate.rb index 2defd66..ccc5099 100644 --- a/lib/mars/rendering/graph/gate.rb +++ b/lib/mars/rendering/graph/gate.rb @@ -10,8 +10,8 @@ def to_graph(builder, parent_id: nil, value: nil) builder.add_node(node_id, name, Node::GATE) builder.add_edge(parent_id, node_id, value) - sink_nodes = branches.map do |condition_result, branch| - branch.to_graph(builder, parent_id: node_id, value: condition_result) + sink_nodes = fallbacks.map do |fallback_key, branch| + branch.to_graph(builder, parent_id: node_id, value: fallback_key) end sink_nodes.flatten diff --git a/lib/mars/workflows/parallel.rb b/lib/mars/workflows/parallel.rb index ef8f3f6..99667cc 100644 --- a/lib/mars/workflows/parallel.rb +++ b/lib/mars/workflows/parallel.rb @@ -12,8 +12,23 @@ def initialize(name, steps:, aggregator: nil, **kwargs) def run(input) errors = [] - results = Async do |workflow| - tasks = @steps.map do |step| + results = execute_steps(input, errors) + + raise AggregateError, errors if errors.any? + + has_global_halt = results.any? { |r| r.is_a?(Halt) && r.global? } + unwrapped = results.map { |r| r.is_a?(Halt) ? r.result : r } + result = aggregator.run(unwrapped) + has_global_halt ? Halt.new(result, scope: :global) : result + end + + private + + attr_reader :steps, :aggregator + + def execute_steps(input, errors) + Async do |workflow| + tasks = steps.map do |step| workflow.async do step.run(input) rescue StandardError => e @@ -23,15 +38,7 @@ def run(input) tasks.map(&:wait) end.result - - raise AggregateError, errors if errors.any? - - aggregator.run(results) end - - private - - attr_reader :steps, :aggregator end end end diff --git a/lib/mars/workflows/sequential.rb b/lib/mars/workflows/sequential.rb index df673c6..9f7a6ed 100644 --- a/lib/mars/workflows/sequential.rb +++ b/lib/mars/workflows/sequential.rb @@ -11,14 +11,14 @@ def initialize(name, steps:, **kwargs) def run(input) @steps.each do |step| - result = step.run(input) - - if result.is_a?(Runnable) - input = result.run(input) - break - else - input = result - end + input = step.run(input) + + next unless input.is_a?(Halt) + + return input if input.global? + + input = input.result + break end input diff --git a/spec/mars/aggregator_spec.rb b/spec/mars/aggregator_spec.rb index 294ec93..408803e 100644 --- a/spec/mars/aggregator_spec.rb +++ b/spec/mars/aggregator_spec.rb @@ -2,7 +2,7 @@ RSpec.describe MARS::Aggregator do describe "#run" do - context "when called without a block" do + context "when called without an operation" do let(:aggregator) { described_class.new } it "returns the input as is" do @@ -11,10 +11,10 @@ end end - context "when initialized with a block operation" do + context "when initialized with an operation" do let(:aggregator) { described_class.new("Aggregator", operation: lambda(&:join)) } - it "executes the block and returns its value" do + it "executes the operation and returns its value" do result = aggregator.run(%w[a b c]) expect(result).to eq("abc") end diff --git a/spec/mars/gate_spec.rb b/spec/mars/gate_spec.rb index d8cf33d..cea8849 100644 --- a/spec/mars/gate_spec.rb +++ b/spec/mars/gate_spec.rb @@ -1,83 +1,127 @@ # frozen_string_literal: true RSpec.describe MARS::Gate do - describe "#run" do - let(:gate) { described_class.new("TestGate", condition: condition, branches: branches) } - - context "with simple boolean condition" do - let(:condition) { ->(input) { input > 5 } } - let(:false_branch) { instance_spy(MARS::Runnable) } - let(:branches) { { false => false_branch } } - - it "returns the input when no branch matches" do - result = gate.run(10) - expect(result).to eq(10) + let(:fallback_step) do + Class.new(MARS::Runnable) do + def run(input) + "fallback: #{input}" end + end.new + end - it "returns the false branch when condition is false" do - result = gate.run(3) + let(:error_step) do + Class.new(MARS::Runnable) do + def run(input) + "error: #{input}" + end + end.new + end - expect(result).to eq(false_branch) + describe "#run" do + context "with constructor-based configuration" do + it "passes through when check returns falsy" do + gate = described_class.new( + "PassGate", + check: ->(_input) {}, + fallbacks: { fail: fallback_step } + ) + + expect(gate.run("hello")).to eq("hello") end - it "does not run the false branch when condition is false" do - gate.run(3) + it "halts with fallback result when check returns a key" do + gate = described_class.new( + "FailGate", + check: ->(_input) { :fail }, + fallbacks: { fail: fallback_step } + ) - expect(false_branch).not_to have_received(:run) + result = gate.run("hello") + expect(result).to be_a(MARS::Halt) + expect(result.result).to eq("fallback: hello") end - end - context "with string-based condition" do - let(:condition) { ->(input) { input.length > 5 ? "long" : "short" } } - let(:long_branch) { instance_spy(MARS::Runnable) } - let(:short_branch) { instance_spy(MARS::Runnable) } - let(:branches) { { "long" => long_branch, "short" => short_branch } } + it "raises when check returns an unregistered key" do + gate = described_class.new( + "BadGate", + check: ->(_input) { :unknown }, + fallbacks: { fail: fallback_step } + ) - it "routes to long branch for long strings" do - result = gate.run("longstring") - - expect(result).to eq(long_branch) + expect { gate.run("hello") }.to raise_error(ArgumentError, /No fallback registered for :unknown/) end - it "routes to short branch for short strings" do - result = gate.run("hi") - - expect(result).to eq(short_branch) + it "selects among multiple fallbacks" do + gate = described_class.new( + "MultiFallback", + check: ->(input) { input[:error_type] }, + fallbacks: { timeout: fallback_step, auth: error_step } + ) + + input = { error_type: :auth } + result = gate.run(input) + expect(result).to be_a(MARS::Halt) + expect(result.result).to eq("error: #{input}") end end - context "with complex condition logic" do - let(:condition) do - lambda do |input| - case input - when 0..10 then "low" - when 11..50 then "medium" - else "high" + context "with class-level DSL" do + let(:fallback_cls) do + Class.new(MARS::Runnable) do + def run(input) + "handled: #{input}" end end end - let(:low_branch) { instance_spy(MARS::Runnable) } - let(:medium_branch) { instance_spy(MARS::Runnable) } - let(:high_branch) { instance_spy(MARS::Runnable) } - let(:branches) { { "low" => low_branch, "medium" => medium_branch, "high" => high_branch } } + it "uses check and fallback DSL" do + cls = fallback_cls + gate_class = Class.new(described_class) do + check { |input| :invalid if input.length > 5 } + fallback :invalid, cls + end - it "routes to low branch" do - result = gate.run(5) + gate = gate_class.new("DSLGate") + expect(gate.run("hi")).to eq("hi") + expect(gate.run("longstring").result).to eq("handled: longstring") + end - expect(result).to eq(low_branch) + it "supports halt_scope DSL" do + cls = fallback_cls + gate_class = Class.new(described_class) do + check { |_input| :fail } + fallback :fail, cls + halt_scope :global + end + + result = gate_class.new("GlobalGate").run("test") + expect(result).to be_a(MARS::Halt) + expect(result).to be_global end + end - it "routes to medium branch" do - result = gate.run(25) + context "with halt scope" do + it "defaults to local scope" do + gate = described_class.new( + "LocalGate", + check: ->(_input) { :fail }, + fallbacks: { fail: fallback_step } + ) - expect(result).to eq(medium_branch) + result = gate.run("hello") + expect(result).to be_local end - it "routes to high branch" do - result = gate.run(100) + it "respects constructor halt_scope" do + gate = described_class.new( + "GlobalGate", + check: ->(_input) { :fail }, + fallbacks: { fail: fallback_step }, + halt_scope: :global + ) - expect(result).to eq(high_branch) + result = gate.run("hello") + expect(result).to be_global end end end diff --git a/spec/mars/halt_spec.rb b/spec/mars/halt_spec.rb new file mode 100644 index 0000000..da3b9c5 --- /dev/null +++ b/spec/mars/halt_spec.rb @@ -0,0 +1,26 @@ +# frozen_string_literal: true + +RSpec.describe MARS::Halt do + describe "#scope" do + it "defaults to :local" do + halt = described_class.new("result") + expect(halt.scope).to eq(:local) + expect(halt).to be_local + expect(halt).not_to be_global + end + + it "can be set to :global" do + halt = described_class.new("result", scope: :global) + expect(halt.scope).to eq(:global) + expect(halt).to be_global + expect(halt).not_to be_local + end + end + + describe "#result" do + it "stores the result" do + halt = described_class.new("hello") + expect(halt.result).to eq("hello") + end + end +end diff --git a/spec/mars/workflows/parallel_spec.rb b/spec/mars/workflows/parallel_spec.rb index ab05d25..e0bcb38 100644 --- a/spec/mars/workflows/parallel_spec.rb +++ b/spec/mars/workflows/parallel_spec.rb @@ -77,6 +77,51 @@ def run(_input) expect(workflow.run(42)).to eq([]) end + it "unwraps local halts and returns plain result" do + gate = MARS::Gate.new( + "LocalBranch", + check: ->(_input) { :branch }, + fallbacks: { + branch: Class.new(MARS::Runnable) do + def run(input) + "branched:#{input}" + end + end.new + } + ) + add_five = add_step_class.new(5) + + workflow = described_class.new("halt_workflow", steps: [gate, add_five]) + + result = workflow.run(10) + # Local halts are unwrapped, aggregated as plain values + expect(result).not_to be_a(MARS::Halt) + expect(result).to eq(["branched:10", 15]) + end + + it "propagates global halt to parent workflow" do + gate = MARS::Gate.new( + "GlobalBranch", + check: ->(_input) { :branch }, + fallbacks: { + branch: Class.new(MARS::Runnable) do + def run(input) + "branched:#{input}" + end + end.new + }, + halt_scope: :global + ) + add_five = add_step_class.new(5) + + workflow = described_class.new("halt_workflow", steps: [gate, add_five]) + + result = workflow.run(10) + expect(result).to be_a(MARS::Halt) + expect(result).to be_global + expect(result.result).to eq(["branched:10", 15]) + end + it "propagates errors from steps" do add_step = add_step_class.new(5) error_step = error_step_class.new("Step failed", "error_step_one") diff --git a/spec/mars/workflows/sequential_spec.rb b/spec/mars/workflows/sequential_spec.rb index 45c0783..01c3e44 100644 --- a/spec/mars/workflows/sequential_spec.rb +++ b/spec/mars/workflows/sequential_spec.rb @@ -62,6 +62,108 @@ def run(_input) expect(workflow.run(42)).to eq(42) end + it "halts locally when a gate triggers with local scope" do + add_five = add_step_class.new(5) + gate = MARS::Gate.new( + "LocalGate", + check: ->(_input) { :branch }, + fallbacks: { + branch: Class.new(MARS::Runnable) do + def run(input) + "branched:#{input}" + end + end.new + } + ) + multiply_three = multiply_step_class.new(3) + + workflow = described_class.new("halt_workflow", steps: [add_five, gate, multiply_three]) + + # 10 + 5 = 15, gate branches -> "branched:15", multiply_three is never reached + # Local halt is consumed — returns plain value + result = workflow.run(10) + expect(result).to eq("branched:15") + expect(result).not_to be_a(MARS::Halt) + end + + it "propagates global halt without unwrapping" do + add_five = add_step_class.new(5) + gate = MARS::Gate.new( + "GlobalGate", + check: ->(_input) { :branch }, + fallbacks: { + branch: Class.new(MARS::Runnable) do + def run(input) + "branched:#{input}" + end + end.new + }, + halt_scope: :global + ) + multiply_three = multiply_step_class.new(3) + + workflow = described_class.new("halt_workflow", steps: [add_five, gate, multiply_three]) + + result = workflow.run(10) + expect(result).to be_a(MARS::Halt) + expect(result).to be_global + expect(result.result).to eq("branched:15") + end + + it "propagates global halt through nested sequential workflows" do + inner_gate = MARS::Gate.new( + "InnerGate", + check: ->(_input) { :stop }, + fallbacks: { + stop: Class.new(MARS::Runnable) do + def run(input) + "stopped:#{input}" + end + end.new + }, + halt_scope: :global + ) + + inner = described_class.new("inner", steps: [inner_gate]) + after_inner = add_step_class.new(100) + outer = described_class.new("outer", steps: [inner, after_inner]) + + result = outer.run(1) + # Global halt propagates through both sequential levels + expect(result).to be_a(MARS::Halt) + expect(result.result).to eq("stopped:1") + end + + it "consumes local halt — outer workflow continues" do + inner_gate = MARS::Gate.new( + "InnerGate", + check: ->(_input) { :stop }, + fallbacks: { + stop: Class.new(MARS::Runnable) do + def run(input) + "stopped:#{input}" + end + end.new + } + # default :local scope + ) + + inner = described_class.new("inner", steps: [inner_gate]) + + # Inner halts locally -> returns "stopped:1" as plain value + string_step = Class.new(MARS::Runnable) do + def run(input) + "after:#{input}" + end + end.new + + outer = described_class.new("outer", steps: [inner, string_step]) + + result = outer.run(1) + expect(result).to eq("after:stopped:1") + expect(result).not_to be_a(MARS::Halt) + end + it "propagates errors from steps" do add_step = add_step_class.new(5) error_step = error_step_class.new("Step failed")