Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion examples/config.ru
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ require_relative 'service_communication'
require_relative 'typed_handlers'
require_relative 'typed_handlers_sorbet'
require_relative 'service_configuration'
require_relative 'middleware'

endpoint = Restate.endpoint(
Greeter,
Expand All @@ -26,7 +27,13 @@ endpoint = Restate.endpoint(
Worker, FanOut,
TicketService,
EventService,
OrderProcessor
OrderProcessor,
MiddlewareDemo
)

# Register handler-level middleware (Sidekiq-style)
endpoint.use(LoggingMiddleware)
endpoint.use(TenantMiddleware)
endpoint.use(MetricsMiddleware, prefix: 'examples')

run endpoint.app
98 changes: 98 additions & 0 deletions examples/middleware.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# typed: true
# frozen_string_literal: true

#
# Example: Handler Middleware
#
# Middleware wraps every handler invocation — like Sidekiq middleware.
# Use it for tracing, metrics, logging, error reporting, tenant isolation, etc.
#
# Each middleware is a class with a `call(handler, ctx)` method that uses
# `yield` to invoke the next middleware or the handler itself. Constructor
# args are passed via `endpoint.use(Klass, args...)`.
#
# Available in `call`:
# handler.name — handler method name ("greet")
# handler.service_tag.name — service name ("Greeter")
# handler.service_tag.kind — "service", "object", or "workflow"
# handler.kind — nil, "exclusive", "shared", or "workflow"
# ctx.request.id — invocation ID
# ctx.request.headers — invocation headers (durable, from caller)
#
# Try it:
# curl localhost:8080/MiddlewareDemo/greet \
# -H 'content-type: application/json' \
# -H 'x-team-id: acme-corp' \
# -d '"World"'

require 'restate'

# ── Logging middleware ──
# Logs every handler invocation with timing.
class LoggingMiddleware
# @param handler [Restate::Handler]
# @param ctx [Restate::Context]
def call(handler, ctx) # rubocop:disable Metrics/AbcSize
service = handler.service_tag.name
name = handler.name
invocation_id = ctx.request.id
puts "[#{service}/#{name}] Starting (invocation: #{invocation_id})"
start = Process.clock_gettime(Process::CLOCK_MONOTONIC)
result = yield
duration = Process.clock_gettime(Process::CLOCK_MONOTONIC) - start
puts "[#{service}/#{name}] Completed in #{(duration * 1000).round(1)}ms"
result
rescue StandardError => e
puts "[#{service}/#{name}] Failed: #{e.class} — #{e.message}"
raise
end
end

# ── Tenant context middleware ──
# Extracts team_id from invocation headers and stores it in fiber-local storage.
# Downstream code can read Thread.current[:team_id] for tenant isolation.
class TenantMiddleware
# @param handler [Restate::Handler]
# @param ctx [Restate::Context]
def call(_handler, ctx)
Thread.current[:team_id] = ctx.request.headers['x-team-id']
yield
ensure
Thread.current[:team_id] = nil
end
end

# ── Metrics middleware (with constructor args) ──
# Demonstrates middleware with configuration. In production you'd use
# a real Prometheus client or StatsD.
class MetricsMiddleware
def initialize(prefix: 'restate')
@prefix = prefix
@counts = Hash.new(0)
end

# @param handler [Restate::Handler]
# @param _ctx [Restate::Context]
def call(handler, _ctx)
key = "#{@prefix}.#{handler.service_tag.name}.#{handler.name}"
@counts[key] += 1
yield
end

def counts
@counts.dup
end
end

# ── Service ──

class MiddlewareDemo < Restate::Service
handler :greet, input: String, output: String
# @param ctx [Restate::Context]
# @param name [String]
# @return [String]
def greet(ctx, name)
team = Thread.current[:team_id] || 'unknown'
ctx.run_sync('build-greeting') { "Hello, #{name}! (team: #{team})" }
end
end
68 changes: 68 additions & 0 deletions lib/restate/endpoint.rb
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@ class Endpoint
sig { returns(T.nilable(String)) }
attr_accessor :protocol

sig { returns(T::Array[T.untyped]) }
attr_reader :middleware

sig { void }
def initialize
@services = T.let({}, T::Hash[String, T.untyped])
@protocol = T.let(nil, T.nilable(String))
@identity_keys = T.let([], T::Array[String])
@middleware = T.let([], T::Array[T.untyped])
end

# Bind one or more services to this endpoint.
Expand Down Expand Up @@ -59,6 +63,70 @@ def identity_key(key)
self
end

# Add handler-level middleware.
#
# Middleware wraps every handler invocation with access to the handler metadata
# and context. Use it for tracing, metrics, logging, error reporting, etc.
#
# A middleware is a class whose instances respond to +call(handler, ctx)+.
# Use +yield+ inside +call+ to invoke the next middleware or the handler.
# The return value of +yield+ is the handler's return value.
#
# This follows the same pattern as {https://github.com/sidekiq/sidekiq/wiki/Middleware Sidekiq middleware}.
#
# @example OpenTelemetry tracing
# class OpenTelemetryMiddleware
# def call(handler, ctx)
# tracer.in_span(handler.name, attributes: {
# 'restate.service' => handler.service_tag.name,
# 'restate.invocation_id' => ctx.request.id
# }) do
# yield
# end
# end
# end
# endpoint.use(OpenTelemetryMiddleware)
#
# @example Metrics
# class MetricsMiddleware
# def call(handler, ctx)
# start = Process.clock_gettime(Process::CLOCK_MONOTONIC)
# result = yield
# duration = Process.clock_gettime(Process::CLOCK_MONOTONIC) - start
# StatsD.timing("restate.handler.#{handler.name}", duration)
# result
# end
# end
# endpoint.use(MetricsMiddleware)
#
# @example Middleware with configuration
# class AuthMiddleware
# def initialize(api_key:)
# @api_key = api_key
# end
#
# def call(handler, ctx)
# raise Restate::TerminalError.new('unauthorized', status_code: 401) unless valid?(ctx)
# yield
# end
# end
# endpoint.use(AuthMiddleware, api_key: 'secret')
#
# @param klass [Class] middleware class (will be instantiated by the SDK)
# @param args [Array] positional arguments for the middleware constructor
# @param kwargs [Hash] keyword arguments for the middleware constructor
# @return [self]
sig { params(klass: T.untyped, args: T.untyped, kwargs: T.untyped).returns(T.self_type) }
def use(klass, *args, **kwargs)
instance = if kwargs.empty?
klass.new(*args)
else
klass.new(*args, **kwargs)
end
@middleware << instance
self
end

# Build and return the Rack-compatible application.
sig { returns(T.untyped) }
def app
Expand Down
33 changes: 23 additions & 10 deletions lib/restate/handler.rb
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,32 @@ def initialize(accept: 'application/json', content_type: 'application/json',

# Invoke a handler with the context and raw input bytes.
# The context is passed as the first argument to every handler.
# Middleware (if any) wraps the handler call.
# Returns raw output bytes.
sig { params(handler: T.untyped, ctx: T.untyped, in_buffer: String).returns(String) }
def invoke_handler(handler:, ctx:, in_buffer:)
if handler.arity == 2
begin
in_arg = handler.handler_io.input_serde.deserialize(in_buffer)
rescue StandardError => e
Kernel.raise TerminalError, "Unable to parse input argument: #{e.message}"
sig do
params(handler: T.untyped, ctx: T.untyped, in_buffer: String,
middleware: T::Array[T.untyped]).returns(String)
end
def invoke_handler(handler:, ctx:, in_buffer:, middleware: []) # rubocop:disable Metrics/AbcSize
call_handler = Kernel.proc do
if handler.arity == 2
begin
in_arg = handler.handler_io.input_serde.deserialize(in_buffer)
rescue StandardError => e
Kernel.raise TerminalError, "Unable to parse input argument: #{e.message}"
end
handler.callable.call(ctx, in_arg)
else
handler.callable.call(ctx)
end
out_arg = handler.callable.call(ctx, in_arg)
else
out_arg = handler.callable.call(ctx)
end

# Build the middleware chain so each middleware can use `yield` to call the next.
chain = middleware.reverse.reduce(call_handler) do |nxt, mw|
Kernel.proc { mw.call(handler, ctx, &nxt) }
end

out_arg = chain.call
handler.handler_io.output_serde.serialize(out_arg)
end
end
3 changes: 2 additions & 1 deletion lib/restate/server.rb
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ def process_invocation(env, handler, request_headers)
handler: handler,
invocation: invocation,
send_output: send_output,
input_queue: input_queue
input_queue: input_queue,
middleware: @endpoint.middleware
)

# Spawn the handler as an async task so the response body can stream
Expand Down
11 changes: 8 additions & 3 deletions lib/restate/server_context.rb
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,19 @@ class ServerContext
sig { returns(T.untyped) }
attr_reader :invocation

sig { params(vm: VMWrapper, handler: T.untyped, invocation: T.untyped, send_output: T.untyped, input_queue: Async::Queue).void }
def initialize(vm:, handler:, invocation:, send_output:, input_queue:)
sig do
params(vm: VMWrapper, handler: T.untyped, invocation: T.untyped, send_output: T.untyped,
input_queue: Async::Queue, middleware: T::Array[T.untyped]).void
end
def initialize(vm:, handler:, invocation:, send_output:, input_queue:, middleware: [])
@vm = T.let(vm, VMWrapper)
@handler = T.let(handler, T.untyped)
@invocation = T.let(invocation, T.untyped)
@send_output = T.let(send_output, T.untyped)
@input_queue = T.let(input_queue, Async::Queue)
@run_coros_to_execute = T.let({}, T::Hash[Integer, T.untyped])
@attempt_finished_event = T.let(AttemptFinishedEvent.new, AttemptFinishedEvent)
@middleware = T.let(middleware, T::Array[T.untyped])
end

# ── Main entry point ──
Expand All @@ -48,7 +52,8 @@ def enter
Thread.current[:restate_service_kind] = @handler.service_tag.kind
Thread.current[:restate_handler_kind] = @handler.kind
in_buffer = @invocation.input_buffer
out_buffer = Restate.invoke_handler(handler: @handler, ctx: self, in_buffer: in_buffer)
out_buffer = Restate.invoke_handler(handler: @handler, ctx: self, in_buffer: in_buffer,
middleware: @middleware)
@vm.sys_write_output_success(out_buffer.b)
@vm.sys_end
rescue TerminalError => e
Expand Down
9 changes: 7 additions & 2 deletions lib/restate/testing.rb
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,17 @@ class RestateTestHarness
# @param restate_image [String] Docker image for Restate server.
# @param always_replay [Boolean] Force replay on every suspension point.
# @param disable_retries [Boolean] Disable Restate retry policy.
# @yield [Endpoint] Optional block to configure the endpoint (e.g. add middleware).
def initialize(*services,
restate_image: 'docker.io/restatedev/restate:latest',
always_replay: false,
disable_retries: false)
disable_retries: false,
&configure)
@services = services
@restate_image = restate_image
@always_replay = always_replay
@disable_retries = disable_retries
@configure = configure
@server_thread = nil
@container = nil
@port = nil
Expand All @@ -69,7 +72,9 @@ def initialize(*services,

def start
@port = find_free_port
rack_app = Restate.endpoint(*@services).app
endpoint = Restate.endpoint(*@services)
@configure&.call(endpoint)
rack_app = endpoint.app
start_sdk_server(rack_app)
wait_for_tcp(@port)
start_restate_container
Expand Down
40 changes: 37 additions & 3 deletions spec/harness_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,32 @@ def greet(_ctx, request)
end
end

# Middleware that stores the invocation in fiber-local storage so the handler can see it ran.
class TestHeaderMiddleware
def call(handler, ctx)
team_id = ctx.request.headers['x-team-id']
Thread.current[:test_team_id] = team_id
yield
ensure
Thread.current[:test_team_id] = nil
end
end

class MiddlewareTestService < Restate::Service
handler def check_header(ctx, _input)
team = Thread.current[:test_team_id] || 'none'
"team:#{team}"
end
end

# ── Helpers ──────────────────────────────────────────────────

def post_json(base_url, path, body)
def post_json(base_url, path, body, headers: {})
uri = URI("#{base_url}#{path}")
request = Net::HTTP::Post.new(uri)
request["Content-Type"] = "application/json"
request["idempotency-key"] = SecureRandom.uuid
headers.each { |k, v| request[k] = v }
request.body = JSON.generate(body)
Net::HTTP.start(uri.hostname, uri.port, read_timeout: 30) { |http| http.request(request) }
end
Expand All @@ -112,8 +131,10 @@ def post_json(base_url, path, body)
before(:all) do
@harness = Restate::Testing::RestateTestHarness.new(
TestGreeter, TestCounter, TestWorker, TestOrchestrator, TestRunSync, TestFiberLocalCtx,
TStructGreeter, TypedGreeter
)
TStructGreeter, TypedGreeter, MiddlewareTestService
) do |endpoint|
endpoint.use(TestHeaderMiddleware)
end
@harness.start
end

Expand Down Expand Up @@ -186,4 +207,17 @@ def post_json(base_url, path, body)
expect(response.code).to eq("200")
expect(JSON.parse(response.body)).to eq("Hi, World!")
end

it "runs handler middleware that extracts headers" do
response = post_json(@harness.ingress_url, "/MiddlewareTestService/check_header", nil,
headers: { "x-team-id" => "acme-corp" })
expect(response.code).to eq("200")
expect(JSON.parse(response.body)).to eq("team:acme-corp")
end

it "runs handler middleware with missing header" do
response = post_json(@harness.ingress_url, "/MiddlewareTestService/check_header", nil)
expect(response.code).to eq("200")
expect(JSON.parse(response.body)).to eq("team:none")
end
end
Loading