-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcapability_checker.go
More file actions
313 lines (267 loc) · 10.5 KB
/
capability_checker.go
File metadata and controls
313 lines (267 loc) · 10.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
package hostlib
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/url"
	"os"
	"strconv"
	"strings"

	"github.com/reglet-dev/reglet-abi/hostfunc"
	"github.com/reglet-dev/reglet-host-sdk/policy"
)
// CapabilityChecker decides whether plugin operations are allowed by the
// capabilities granted to each plugin. The actual evaluation is delegated to
// the SDK's typed policy.Policy; this type adds per-plugin grant lookup and
// denial reporting on top of it.
//
// NOTE(review): grantedCapabilities is a plain map that RegisterGrants writes
// without locking — confirm RegisterGrants is never called concurrently with
// the Check* methods.
type CapabilityChecker struct {
	// policy performs the capability evaluation itself.
	policy policy.Policy

	// grantedCapabilities maps a plugin name to the grants that plugin holds.
	grantedCapabilities map[string]*hostfunc.GrantSet

	// cwd is the current working directory used for resolving relative paths.
	cwd string

	// denialHandler, when non-nil, is notified about every denied check.
	denialHandler DenialHandler
}
// DenialHandler is called when a capability is denied.
// It allows custom logging or auditing.
type DenialHandler func(ctx context.Context, pluginName, capabilityKind, pattern, message string)
type (
	// CapabilityCheckerOption configures a CapabilityChecker at construction time.
	CapabilityCheckerOption func(*capabilityCheckerConfig)

	// capabilityCheckerConfig collects the settings applied by
	// CapabilityCheckerOption values before NewCapabilityChecker builds the
	// checker.
	capabilityCheckerConfig struct {
		cwd               string        // working directory; empty means NewCapabilityChecker falls back to os.Getwd
		symlinkResolution bool          // whether the policy resolves symlinks (NewCapabilityChecker defaults it to true)
		denialHandler     DenialHandler // optional hook notified of denied checks
	}
)
// WithCapabilityWorkingDirectory overrides the working directory used to
// resolve relative paths. Passing an empty string keeps the default, which is
// the process working directory captured by NewCapabilityChecker.
func WithCapabilityWorkingDirectory(cwd string) CapabilityCheckerOption {
	return func(cfg *capabilityCheckerConfig) { cfg.cwd = cwd }
}
// WithCapabilitySymlinkResolution turns the policy's symlink resolution on or
// off. NewCapabilityChecker enables it unless this option says otherwise.
func WithCapabilitySymlinkResolution(enabled bool) CapabilityCheckerOption {
	return func(cfg *capabilityCheckerConfig) { cfg.symlinkResolution = enabled }
}
// WithCapabilityDenialHandler installs a hook that is called for every denied
// capability check. A nil handler disables the hook.
func WithCapabilityDenialHandler(handler DenialHandler) CapabilityCheckerOption {
	return func(cfg *capabilityCheckerConfig) { cfg.denialHandler = handler }
}
// NewCapabilityChecker builds a CapabilityChecker for the given per-plugin
// grants. Options are applied in order on top of the defaults (symlink
// resolution enabled). When no working directory is supplied, the process
// working directory is captured here, once, so that later capability checks
// do not have to touch the environment.
//
// The caps map is stored as-is rather than copied, so RegisterGrants may
// write into the caller's map.
func NewCapabilityChecker(caps map[string]*hostfunc.GrantSet, opts ...CapabilityCheckerOption) *CapabilityChecker {
	cfg := capabilityCheckerConfig{symlinkResolution: true}
	for i := range opts {
		opts[i](&cfg)
	}

	cwd := cfg.cwd
	if cwd == "" {
		// Best effort: the constructor cannot report errors, so a failing
		// Getwd simply leaves whatever it returned (possibly empty) in place.
		cwd, _ = os.Getwd()
	}

	p := policy.NewPolicy(
		policy.WithWorkingDirectory(cwd),
		policy.WithSymlinkResolution(cfg.symlinkResolution),
	)

	return &CapabilityChecker{
		policy:              p,
		grantedCapabilities: caps,
		cwd:                 cwd,
		denialHandler:       cfg.denialHandler,
	}
}
// RegisterGrants records grants for pluginName, replacing anything previously
// registered under that name. The grant map is created on first use, so this
// also works on a checker that was constructed with a nil map.
//
// NOTE(review): the map is written without synchronization; when the checker
// was built from a caller-supplied map, that map is mutated as well.
func (c *CapabilityChecker) RegisterGrants(pluginName string, grants *hostfunc.GrantSet) {
	m := c.grantedCapabilities
	if m == nil {
		m = make(map[string]*hostfunc.GrantSet)
		c.grantedCapabilities = m
	}
	m[pluginName] = grants
}
// CheckNetwork reports whether pluginName may perform the network request req.
// It returns nil when the policy allows the request. Otherwise it returns the
// denial error produced by handleDeny, with pattern "host:port" and reason
// "no capabilities granted" (no grants registered for the plugin) or
// "network capability denied".
func (c *CapabilityChecker) CheckNetwork(ctx context.Context, pluginName string, req hostfunc.NetworkRequest) error {
	grants := c.grantedCapabilities[pluginName]
	if grants != nil && c.policy.CheckNetwork(req, grants) {
		return nil
	}

	reason := "network capability denied"
	if grants == nil {
		reason = "no capabilities granted"
	}
	return c.handleDeny(ctx, pluginName, "network", fmt.Sprintf("%s:%d", req.Host, req.Port), reason)
}
// CheckNetworkConnection reports whether pluginName may connect to host:port.
//
// It first runs the policy's silent EvaluateNetwork, so an allowed connection
// returns nil without going through the policy's CheckNetwork. Only when that
// evaluation fails is the "loud" CheckNetwork invoked. Its result is
// deliberately discarded, and the denial is then reported via handleDeny.
func (c *CapabilityChecker) CheckNetworkConnection(ctx context.Context, pluginName, host string, port int) error {
	grants := c.grantedCapabilities[pluginName]
	if grants == nil {
		return c.handleDeny(ctx, pluginName, "network", fmt.Sprintf("%s:%d", host, port), "no capabilities granted")
	}

	req := hostfunc.NetworkRequest{Host: host, Port: port}
	if c.policy.EvaluateNetwork(req, grants) {
		return nil
	}

	// Re-run as the loud check; only its side effects matter here.
	_ = c.policy.CheckNetwork(req, grants)
	return c.handleDeny(ctx, pluginName, "network", fmt.Sprintf("%s:%d", host, port), "network capability denied")
}
// CheckFileSystem reports whether pluginName may perform the filesystem
// request req. It returns nil when the policy allows the request. Otherwise
// it returns the denial error produced by handleDeny for req.Path, with
// reason "no capabilities granted" or "filesystem capability denied".
func (c *CapabilityChecker) CheckFileSystem(ctx context.Context, pluginName string, req hostfunc.FileSystemRequest) error {
	grants := c.grantedCapabilities[pluginName]
	if grants != nil && c.policy.CheckFileSystem(req, grants) {
		return nil
	}

	reason := "filesystem capability denied"
	if grants == nil {
		reason = "no capabilities granted"
	}
	return c.handleDeny(ctx, pluginName, "fs", req.Path, reason)
}
// CheckEnvironment reports whether pluginName may access the environment
// variable named in req. It returns nil when the policy allows the access.
// Otherwise it returns the denial error produced by handleDeny for
// req.Variable, with reason "no capabilities granted" or
// "environment capability denied".
func (c *CapabilityChecker) CheckEnvironment(ctx context.Context, pluginName string, req hostfunc.EnvironmentRequest) error {
	grants := c.grantedCapabilities[pluginName]
	if grants != nil && c.policy.CheckEnvironment(req, grants) {
		return nil
	}

	reason := "environment capability denied"
	if grants == nil {
		reason = "no capabilities granted"
	}
	return c.handleDeny(ctx, pluginName, "env", req.Variable, reason)
}
// CheckExec reports whether pluginName may execute the command in req. It
// returns nil when the policy allows the command. Otherwise it returns the
// denial error produced by handleDeny for req.Command, with reason
// "no capabilities granted" or "exec capability denied".
func (c *CapabilityChecker) CheckExec(ctx context.Context, pluginName string, req hostfunc.ExecCapabilityRequest) error {
	grants := c.grantedCapabilities[pluginName]
	if grants != nil && c.policy.CheckExec(req, grants) {
		return nil
	}

	reason := "exec capability denied"
	if grants == nil {
		reason = "no capabilities granted"
	}
	return c.handleDeny(ctx, pluginName, "exec", req.Command, reason)
}
// handleDeny reports a denied capability check and returns the error the
// caller should propagate. The error text is "<message>: <pattern>". The same
// text is passed to the configured DenialHandler, if any, together with the
// plugin name, the capability kind and the pattern.
func (c *CapabilityChecker) handleDeny(ctx context.Context, pluginName, kind, pattern, message string) error {
	fullMsg := message + ": " + pattern
	if c.denialHandler != nil {
		c.denialHandler(ctx, pluginName, kind, pattern, fullMsg)
	}
	// The message is already fully formatted; errors.New avoids a pointless
	// fmt.Errorf("%s", ...) round-trip and yields the same error text.
	return errors.New(fullMsg)
}
// AllowsPrivateNetwork reports whether pluginName may reach private network
// addresses. It answers by silently evaluating a probe request to the loopback
// address 127.0.0.1 (port 0) against the plugin's grants. Other private ranges
// are not probed individually. A plugin without registered grants is never
// allowed.
func (c *CapabilityChecker) AllowsPrivateNetwork(pluginName string) bool {
	if grants := c.grantedCapabilities[pluginName]; grants != nil {
		probe := hostfunc.NetworkRequest{Host: "127.0.0.1", Port: 0}
		return c.policy.EvaluateNetwork(probe, grants)
	}
	return false
}
// ToCapabilityGetter adapts this checker to the CapabilityGetter signature.
//
// The returned function ignores its plugin argument; every query is checked
// against pluginName, which is bound here. For a capability of the form
// "env:NAME", the environment grant for NAME is tried first. If that fails,
// the whole capability string is checked as an exec grant. Any other
// capability string is checked as an exec grant directly. Failed checks go
// through the normal denial path, so the DenialHandler may be notified even
// when a fallback check later succeeds.
func (c *CapabilityChecker) ToCapabilityGetter(ctx context.Context, pluginName string) CapabilityGetter {
	return func(_, capability string) bool {
		varName, isEnv := strings.CutPrefix(capability, "env:")
		if isEnv && c.CheckEnvironment(ctx, pluginName, hostfunc.EnvironmentRequest{Variable: varName}) == nil {
			return true
		}
		execReq := hostfunc.ExecCapabilityRequest{Command: capability}
		return c.CheckExec(ctx, pluginName, execReq) == nil
	}
}
// CapabilityMiddleware returns a middleware that enforces capabilities for
// standard host functions.
//
// The host function name is taken from ctx when it implements HostContext. The
// calling plugin is taken from CapabilityPluginNameFromContext, and calls
// without a plugin name are passed to next unchecked. For the recognised
// functions (dns_lookup, tcp_connect, smtp_connect, http_request,
// exec_command), the payload is decoded and checked against the plugin's
// grants. A denied request is answered with a JSON validation error (and a nil
// Go error) instead of reaching next. So is a payload that cannot be decoded:
// previously such payloads skipped the check entirely and were passed to next.
// All other function names are passed through.
func CapabilityMiddleware(checker *CapabilityChecker) Middleware {
	return func(next ByteHandler) ByteHandler {
		return func(ctx context.Context, payload []byte) ([]byte, error) {
			funcName := ""
			if hc, ok := ctx.(HostContext); ok {
				funcName = hc.FunctionName()
			}

			pluginName, ok := CapabilityPluginNameFromContext(ctx)
			if !ok {
				return next(ctx, payload)
			}

			// Add SSRF protection context based on plugin capabilities.
			// NOTE(review): a plain string key trips staticcheck SA1029, but other
			// code presumably reads "ssrf_allow_private" by this exact key, so it
			// is kept. Wrapping ctx also means next no longer receives a
			// HostContext — confirm downstream handlers do not type-assert it.
			allowPrivate := checker.AllowsPrivateNetwork(pluginName)
			ctx = context.WithValue(ctx, "ssrf_allow_private", allowPrivate)

			if err := checkHostCallCapability(ctx, checker, pluginName, funcName, payload); err != nil {
				return NewValidationError(err.Error()).ToJSON(), nil
			}
			return next(ctx, payload)
		}
	}
}

// checkHostCallCapability decodes payload according to funcName and checks
// the resulting request against pluginName's grants. It returns nil for
// function names it does not recognise, and an error for undecodable payloads
// or denied requests.
func checkHostCallCapability(ctx context.Context, checker *CapabilityChecker, pluginName, funcName string, payload []byte) error {
	switch funcName {
	case "dns_lookup":
		var req hostfunc.DNSRequest
		if err := decodeHostCallPayload(funcName, payload, &req); err != nil {
			return err
		}
		return checker.CheckNetwork(ctx, pluginName, hostfunc.NetworkRequest{Host: req.Hostname, Port: 53})

	case "tcp_connect":
		var req hostfunc.TCPRequest
		if err := decodeHostCallPayload(funcName, payload, &req); err != nil {
			return err
		}
		// An unparseable port is deliberately checked as whatever Atoi returns
		// (0 for empty/non-numeric input), so only grants that do not pin a
		// specific port can match.
		// NOTE(review): confirm the policy never treats port 0 as "any port".
		port, _ := strconv.Atoi(req.Port)
		return checker.CheckNetwork(ctx, pluginName, hostfunc.NetworkRequest{Host: req.Host, Port: port})

	case "smtp_connect":
		var req hostfunc.SMTPRequest
		if err := decodeHostCallPayload(funcName, payload, &req); err != nil {
			return err
		}
		// Same port handling as tcp_connect above.
		port, _ := strconv.Atoi(req.Port)
		return checker.CheckNetwork(ctx, pluginName, hostfunc.NetworkRequest{Host: req.Host, Port: port})

	case "http_request":
		var req hostfunc.HTTPRequest
		if err := decodeHostCallPayload(funcName, payload, &req); err != nil {
			return err
		}
		return checkHTTPCapability(ctx, checker, pluginName, req.URL)

	case "exec_command":
		var req hostfunc.ExecRequest
		if err := decodeHostCallPayload(funcName, payload, &req); err != nil {
			return err
		}
		dangerous := IsDangerousExecution(req.Command, req.Args)
		if err := checker.CheckExec(ctx, pluginName, hostfunc.ExecCapabilityRequest{Command: req.Command}); err != nil {
			if dangerous {
				// Dangerous executions get a message that names the required grant.
				execType := GetExecutionTypeDescription(req.Command, req.Args)
				return fmt.Errorf("%s requires 'exec:%s' capability", execType, req.Command)
			}
			return err
		}
	}
	return nil
}

// decodeHostCallPayload unmarshals a host-function JSON payload into dst. The
// function name is included in the error so the plugin can see which call was
// rejected.
func decodeHostCallPayload(funcName string, payload []byte, dst any) error {
	if err := json.Unmarshal(payload, dst); err != nil {
		return fmt.Errorf("invalid %s payload: %w", funcName, err)
	}
	return nil
}
// checkHTTPCapability checks whether pluginName may open a connection to the
// host and port targeted by rawURL. When the URL has no explicit port, 443 is
// assumed for https and 80 for every other scheme.
//
// It returns an "invalid URL" error when rawURL cannot be parsed, has no host,
// or names a port outside 1-65535. Otherwise it returns the result of
// CheckNetworkConnection.
func checkHTTPCapability(ctx context.Context, checker *CapabilityChecker, pluginName, rawURL string) error {
	parsedURL, err := url.Parse(rawURL)
	if err != nil {
		return fmt.Errorf("invalid URL: %w", err)
	}

	host := parsedURL.Hostname()
	if host == "" {
		// Scheme-less or relative URLs have no host to authorize.
		return fmt.Errorf("invalid URL: missing host")
	}

	portStr := parsedURL.Port()
	if portStr == "" {
		if parsedURL.Scheme == "https" {
			portStr = "443"
		} else {
			portStr = "80"
		}
	}

	// url.Parse only guarantees the port is all digits, not that it fits in a
	// valid TCP port range, so the conversion error must not be ignored.
	port, err := strconv.Atoi(portStr)
	if err != nil || port < 1 || port > 65535 {
		return fmt.Errorf("invalid URL: port %q out of range", portStr)
	}

	return checker.CheckNetworkConnection(ctx, pluginName, host, port)
}
// Context helpers for plugin name propagation.

// capabilityContextKey is an unexported key type, so values stored under it
// cannot collide with context keys defined by other packages.
type capabilityContextKey struct {
	name string // human-readable label only; the key is compared by pointer identity
}

// pluginNameContextKey is the context key under which the calling plugin's
// name is stored.
var pluginNameContextKey = &capabilityContextKey{"plugin_name"}
// WithCapabilityPluginName returns a context derived from ctx that carries
// name as the calling plugin's name. Retrieve it with
// CapabilityPluginNameFromContext.
func WithCapabilityPluginName(ctx context.Context, name string) context.Context {
	derived := context.WithValue(ctx, pluginNameContextKey, name)
	return derived
}
// CapabilityPluginNameFromContext returns the plugin name stored in ctx by
// WithCapabilityPluginName. The boolean is false, and the name empty, when no
// string value is stored under the plugin-name key.
func CapabilityPluginNameFromContext(ctx context.Context) (string, bool) {
	if name, found := ctx.Value(pluginNameContextKey).(string); found {
		return name, true
	}
	return "", false
}