diff --git a/.gitignore b/.gitignore index 02a57e664..70b7e4302 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,4 @@ dist/* conf/wasm_plugin .DS_Store +.git* diff --git a/bfe_basic/condition/build.go b/bfe_basic/condition/build.go index 9f57b55d4..5a5c0ce26 100644 --- a/bfe_basic/condition/build.go +++ b/bfe_basic/condition/build.go @@ -548,6 +548,17 @@ func buildPrimitive(node *parser.CallExpr) (Condition, error) { fetcher: &BfeTimeFetcher{}, matcher: matcher, }, nil + + case "req_body_json_in": + return &PrimitiveCond{ + name: node.Fun.Name, + node: node, + fetcher: &ReqBodyJsonFetcher{ + path: node.Args[0].Value, + }, + matcher: NewInMatcher(node.Args[1].Value, node.Args[2].ToBool()), + }, nil + default: return nil, fmt.Errorf("unsupported primitive %s", node.Fun.Name) } diff --git a/bfe_basic/condition/parser/semant.go b/bfe_basic/condition/parser/semant.go index 65f88eb7b..b47ededab 100644 --- a/bfe_basic/condition/parser/semant.go +++ b/bfe_basic/condition/parser/semant.go @@ -79,6 +79,7 @@ var funcProtos = map[string][]Token{ "req_context_value_in": {STRING, STRING, BOOL}, "bfe_time_range": []Token{STRING, STRING}, "bfe_periodic_time_range": []Token{STRING, STRING, STRING}, + "req_body_json_in": []Token{STRING, STRING, BOOL}, } func prototypeCheck(expr *CallExpr) error { diff --git a/bfe_basic/condition/primitive.go b/bfe_basic/condition/primitive.go index 6967bac8f..cc61c3e71 100644 --- a/bfe_basic/condition/primitive.go +++ b/bfe_basic/condition/primitive.go @@ -26,14 +26,14 @@ import ( "strconv" "strings" "time" -) -import ( "github.com/bfenetworks/bfe/bfe_basic" "github.com/bfenetworks/bfe/bfe_basic/condition/parser" "github.com/bfenetworks/bfe/bfe_util" "github.com/bfenetworks/bfe/bfe_util/net_util" "github.com/spaolacci/murmur3" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) const ( @@ -1089,3 +1089,70 @@ func (t *PeriodicTimeMatcher) Match(v interface{}) bool { seconds := hour*3600 + minute*60 + second return seconds >= t.startTime 
&& seconds <= t.endTime } + +type ReqBodyJsonFetcher struct{ + path string +} + +func (pf *ReqBodyJsonFetcher) Fetch(req *bfe_basic.Request) (interface{}, error) { + return ReqBodyJsonFetch(req, pf.path) +} + +func ReqBodyJsonFetch(req *bfe_basic.Request, path string) (string, error) { + const jsonCachePrefix = "jsoncache." + + if req == nil || req.HttpRequest == nil { + return "", fmt.Errorf("fetcher: nil pointer") + } + + cachepath := jsonCachePrefix + path + cachedVal := req.GetContext(cachepath) + if cachedVal != nil { + str, ok := cachedVal.(string) + if ok { + return str, nil + } + } + + bodyAccessor, err := req.HttpRequest.GetBodyAccessor() + if bodyAccessor == nil { + return "", err + } + + body, _ := bodyAccessor.GetBytes() + val := gjson.GetBytes(body, path) + if !val.Exists() { + req.SetContext(cachepath, "") + return "", nil + } + + str := val.String() + req.SetContext(cachepath, str) + return str, nil +} + +func ReqBodyJsonSet(req *bfe_basic.Request, path string, value string) error { + if req == nil || req.OutRequest == nil { + return fmt.Errorf("set json body error: nil pointer") + } + + bodyAccessor, err := req.OutRequest.GetBodyAccessor() + if bodyAccessor == nil { + return err + } + + body, _ := bodyAccessor.GetBytes() + var newBody []byte + if path == "" { + newBody = []byte(value) + } else { + newBody, err = sjson.SetBytes(body, path, value) + if err != nil { + return fmt.Errorf("set json body error, path: %s, value: %s, err: %v", path, value, err) + } + } + + bodyAccessor.SetBytes(newBody, false) + + return nil +} diff --git a/bfe_basic/error_code.go b/bfe_basic/error_code.go index 48ae6d2d9..60191c342 100644 --- a/bfe_basic/error_code.go +++ b/bfe_basic/error_code.go @@ -47,6 +47,7 @@ var ( ErrBkRetryTooMany = errors.New("BK_RETRY_TOOMANY") // reach retry max ErrBkNoSubClusterCross = errors.New("BK_NO_SUB_CLUSTER_CROSS") // no sub-cluster found ErrBkCrossRetryBalance = errors.New("BK_CROSS_RETRY_BALANCE") // cross retry balance failed + 
ErrBkBodyProcess = errors.New("BK_BODY_PROCESS") // body process error // GSLB error ErrGslbBlackhole = errors.New("GSLB_BLACKHOLE") // deny by blackhole diff --git a/bfe_config/bfe_cluster_conf/cluster_conf/cluster_conf_load.go b/bfe_config/bfe_cluster_conf/cluster_conf/cluster_conf_load.go index 11ae55194..d75c9c59a 100644 --- a/bfe_config/bfe_cluster_conf/cluster_conf/cluster_conf_load.go +++ b/bfe_config/bfe_cluster_conf/cluster_conf/cluster_conf_load.go @@ -125,6 +125,12 @@ type BackendHTTPS struct { protocol string // protocol of backend https } +type AIConf struct { + Type int // type of LLM service, reserved for future use. should be 0 now. + ModelMapping *map[string]string // model mapping, key is model name in req, value is model name in backend + Key *string // API key for AI service +} + func (conf *BackendHTTPS) GetProtocol() string { return conf.protocol } @@ -222,6 +228,7 @@ type ClusterConf struct { GslbBasic *GslbBasicConf // gslb basic conf for cluster ClusterBasic *ClusterBasicConf // basic conf for cluster HTTPSConf *BackendHTTPS // backend's https conf + AIConf *AIConf // ai conf for cluster } type ClusterToConf map[string]ClusterConf diff --git a/bfe_http/request.go b/bfe_http/request.go index 9a96d5fa8..397b4ead7 100644 --- a/bfe_http/request.go +++ b/bfe_http/request.go @@ -34,10 +34,9 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" -) -import ( "github.com/bfenetworks/bfe/bfe_bufio" "github.com/bfenetworks/bfe/bfe_net/textproto" "github.com/bfenetworks/bfe/bfe_tls" @@ -974,3 +973,49 @@ func (r *Request) closeBody() { r.Body.Close() } } +func (r *Request) GetBodyAccessor() (BodyAccessor, error) { + if r.Body == nil { + return nil, nil + } + body := r.Body + for { + bodyAccessor, ok := body.(BodyAccessor) + if ok { + return bodyAccessor, nil + } + + sourcer, ok := body.(SourceGetter) + if !ok { + break + } + + body = sourcer.GetSource() + } + + // If the body is not a BodyAccessor, we will try to convert it to a BytesBody + var 
err error + r.Body, err = NewBytesBody(r.Body, GetAccessibleBodySize()) + if err != nil { + return nil, fmt.Errorf("can't get body") + } + return r.Body.(BodyAccessor), nil +} + +const DefaultAccessibleBodySize = 1024*1024*2 +const MaxAccessibleBodySize = 1024*1024*8 +var accessibleBodySize = int64(DefaultAccessibleBodySize) + +func SetAccessibleBodySize(size int64) { + if size <= 0 { + size = DefaultAccessibleBodySize + } + atomic.StoreInt64(&accessibleBodySize, size) +} + +func GetAccessibleBodySize() int64 { + return atomic.LoadInt64(&accessibleBodySize) +} + +type SourceGetter interface { + GetSource() io.ReadCloser +} diff --git a/bfe_http/transfer.go b/bfe_http/transfer.go index a6daeb6dd..b6b96ca6d 100644 --- a/bfe_http/transfer.go +++ b/bfe_http/transfer.go @@ -749,3 +749,78 @@ func parseContentLength(cl string) (int64, error) { return n, nil } + +type BodyAccessor interface { + GetBytes() ([]byte, bool) + SetBytes([]byte, bool) +} + +//body with BodyAccessor interface +type bytes_body struct { + src io.ReadCloser // source body + buf []byte // bytes read out from src + all bool // all already read out from src to buf + r io.Reader // multiReader of buf and src +} + +func (b *bytes_body) Read(p []byte) (n int, err error) { + return b.r.Read(p) +} + +func (b *bytes_body) Close() error { + return b.src.Close() +} + +func (b *bytes_body) Peek(n int) ([]byte, error) { + if n < 0 { + return nil, fmt.Errorf("negative peek count") + } + if n > len(b.buf) { + n = len(b.buf) + } + return b.buf[:n], nil +} + +func (b *bytes_body) ForcePeek(n int) ([]byte, error) { + return b.Peek(n) +} + +func (b *bytes_body) GetBytes() ([]byte, bool) { + return b.buf, b.all +} + +func (b *bytes_body) SetBytes(newBuf []byte, all bool) { + b.buf = newBuf + br := bytes.NewBuffer(newBuf) + b.all = b.all || all + if b.all { + b.r = br + } else { + b.r = io.MultiReader(br, b.src) + } +} + +func NewBytesBody(src io.ReadCloser, maxSize int64) (*bytes_body, error) { + bb, err := 
io.ReadAll(io.LimitReader(src, maxSize)) + if err != nil { + return nil, fmt.Errorf("io.ReadAll: %s", err.Error()) + } + + br := bytes.NewBuffer(bb) + + if len(bb) < int(maxSize) { + return &bytes_body{ + src: src, + buf: bb, + all: true, + r: br, + }, nil + } else { + return &bytes_body{ + src: src, + buf: bb, + all: false, + r: io.MultiReader(br, src), + }, nil + } +} diff --git a/bfe_modules/bfe_modules.go b/bfe_modules/bfe_modules.go index 16bb4614b..308e13900 100644 --- a/bfe_modules/bfe_modules.go +++ b/bfe_modules/bfe_modules.go @@ -19,10 +19,12 @@ package bfe_modules import ( "github.com/bfenetworks/bfe/bfe_module" "github.com/bfenetworks/bfe/bfe_modules/mod_access" + "github.com/bfenetworks/bfe/bfe_modules/mod_ai_token_auth" "github.com/bfenetworks/bfe/bfe_modules/mod_auth_basic" "github.com/bfenetworks/bfe/bfe_modules/mod_auth_jwt" "github.com/bfenetworks/bfe/bfe_modules/mod_auth_request" "github.com/bfenetworks/bfe/bfe_modules/mod_block" + "github.com/bfenetworks/bfe/bfe_modules/mod_body_process" "github.com/bfenetworks/bfe/bfe_modules/mod_compress" "github.com/bfenetworks/bfe/bfe_modules/mod_cors" "github.com/bfenetworks/bfe/bfe_modules/mod_doh" @@ -139,7 +141,12 @@ var moduleList = []bfe_module.BfeModule{ // mod_unified_waf mod_unified_waf.NewModuleWaf(), -} + + // mod_ai_token_auth + mod_ai_token_auth.NewModuleAITokenAuth(), + + // mod_body_process + mod_body_process.NewModuleBodyProcess(),} // init modules list func InitModuleList(modules []bfe_module.BfeModule) { diff --git a/bfe_modules/mod_ai_token_auth/conf_mod_ai_token_auth.go b/bfe_modules/mod_ai_token_auth/conf_mod_ai_token_auth.go new file mode 100644 index 000000000..31a45c062 --- /dev/null +++ b/bfe_modules/mod_ai_token_auth/conf_mod_ai_token_auth.go @@ -0,0 +1,95 @@ +// Copyright (c) 2025 The BFE Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mod_ai_token_auth + +import ( + "fmt" + + "github.com/baidu/go-lib/log" + "github.com/bfenetworks/bfe/bfe_util" + "github.com/bfenetworks/bfe/bfe_util/redis_client" + gcfg "gopkg.in/gcfg.v1" +) + +type ConfModAITokenAuth struct { + Basic struct { + ProductRulePath string + } + + // redis conf + Redis struct { + Bns string // bns name for redis proxy + ConnectTimeout int // connect timeout (ms) + ReadTimeout int // read timeout (ms) + WriteTimeout int // write timeout(ms) + + // max idle connections in pool + MaxIdle int + + // redis password,ignore if not set + Password string + + // max active connections in pool, + // when set 0, there is no connection num limit + MaxActive int + } + + Log struct { + OpenDebug bool + } +} + +func (cfg *ConfModAITokenAuth) Check(confRoot string) error { + if cfg.Basic.ProductRulePath == "" { + log.Logger.Warn("ModAITokenAuth.ProductRulePath not set, use default value") + cfg.Basic.ProductRulePath = "mod_ai_toekn_auth/token_rule.data" + } + + cfg.Basic.ProductRulePath = bfe_util.ConfPathProc(cfg.Basic.ProductRulePath, confRoot) + + // check redis server conf + if err := redis_client.CheckRedisConf(cfg.Redis.Bns); err != nil { + return err + } + + // check connectTimeOut + if cfg.Redis.ConnectTimeout <= 0 { + return fmt.Errorf("Redis.ConnectTimeout must > 0") + } + + // check Read/Write Timeout + if cfg.Redis.ReadTimeout <= 0 || cfg.Redis.WriteTimeout <= 0 { + return fmt.Errorf("Redis.ReadTimeout/WriteTimeout must > 0") + } + + return nil +} + +func ConfLoad(filePath string, confRoot string) 
(*ConfModAITokenAuth, error) { + var cfg ConfModAITokenAuth + var err error + + err = gcfg.ReadFileInto(&cfg, filePath) + if err != nil { + return &cfg, err + } + + err = cfg.Check(confRoot) + if err != nil { + return &cfg, err + } + + return &cfg, nil +} diff --git a/bfe_modules/mod_ai_token_auth/mod_ai_token_auth.go b/bfe_modules/mod_ai_token_auth/mod_ai_token_auth.go new file mode 100644 index 000000000..ced113249 --- /dev/null +++ b/bfe_modules/mod_ai_token_auth/mod_ai_token_auth.go @@ -0,0 +1,324 @@ +// Copyright (c) 2025 The BFE Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package mod_ai_token_auth + +import ( + "fmt" + "net/url" + "strings" + + "github.com/baidu/go-lib/log" + "github.com/baidu/go-lib/web-monitor/metrics" + "github.com/baidu/go-lib/web-monitor/web_monitor" + + "github.com/bfenetworks/bfe/bfe_basic" + "github.com/bfenetworks/bfe/bfe_http" + "github.com/bfenetworks/bfe/bfe_module" + "github.com/bfenetworks/bfe/bfe_util/redis_client" +) + +const ( + ModAITokenAuth = "mod_ai_token_auth" +) + +var ( + openDebug = false +) + +type ModuleAITokenAuthState struct { + ReqTotal *metrics.Counter + ReqAuth *metrics.Counter + ReqAuthFail *metrics.Counter +} + +type ModuleAITokenAuth struct { + name string + conf *ConfModAITokenAuth + ruleTable *TokenRuleTable + state ModuleAITokenAuthState + metrics metrics.Metrics + + redisClient redis_client.Client // redis client +} + +func NewModuleAITokenAuth() *ModuleAITokenAuth { + m := new(ModuleAITokenAuth) + m.name = ModAITokenAuth + m.metrics.Init(&m.state, ModAITokenAuth, 0) + m.ruleTable = NewTokenRuleTable() + return m +} + +func (m *ModuleAITokenAuth) Name() string { + return m.name +} + +func (m *ModuleAITokenAuth) loadProductRuleConf(query url.Values) error { + path := query.Get("path") + if path == "" { + path = m.conf.Basic.ProductRulePath + } + + conf, err := ProductRuleConfLoad(path) + if err != nil { + return fmt.Errorf("err in ProductRuleConfLoad(%s): %s", path, err) + } + + oldtokens := m.ruleTable.Update(conf) + // clean old tokens' used quota in redis + for _, t := range oldtokens { + key := usedQuotaKey(t.Key, t.UpdateTime) + m.redisClient.Expire(key, 3600) + } + + return nil +} + +func (m *ModuleAITokenAuth) matchTokenRule(req *bfe_basic.Request) bool { + if openDebug { + log.Logger.Debug("%s check request", m.name) + } + m.state.ReqTotal.Inc(1) + + rules, ok := m.ruleTable.Search(req.Route.Product) + if !ok { + if openDebug { + log.Logger.Debug("%s product %s not found, just pass", m.name, req.Route.Product) + } + return false + } + + for _, rule := range *rules { + 
if openDebug { + log.Logger.Debug("%s process rule: %v", m.name, rule) + } + + if rule.Cond.Match(req) { + return true + } + } + + return false +} + +func (m *ModuleAITokenAuth) tokenReadResponseHandler(req *bfe_basic.Request, res *bfe_http.Response) int { + ctx := GetTokenAuthContext(req) // ensure token auth context is set + if ctx == nil { + return bfe_module.BfeHandlerGoOn + } + + if res.ContentLength >= 0 { + ctx.CompletionTokens = int64(res.ContentLength) / 4 // estimate completion tokens + } + return bfe_module.BfeHandlerGoOn +} + +func CalcReqUsedQuota(req *bfe_basic.Request, promptTokens, completionTokens int64) int64 { + // calculate used quota based on prompt and completion tokens + if promptTokens < 0 || completionTokens < 0 { + return 0 + } + return promptTokens + completionTokens +} + +func (m *ModuleAITokenAuth) tokenRequestFinishHandler(req *bfe_basic.Request, res *bfe_http.Response) int { + ctx := GetTokenAuthContext(req) // ensure token auth context is set + if ctx == nil { + return bfe_module.BfeHandlerGoOn + } + + ctx.UsedQuota = CalcReqUsedQuota(req, ctx.PromptTokens, ctx.CompletionTokens) // calculate used quota + if ctx.UsedQuota > 0 { + m.IncrTokenUsedQuotaBy(ctx.Token, ctx.UsedQuota) // increment token used quota + } + + return bfe_module.BfeHandlerGoOn +} + +func SetApiKey(req *bfe_http.Request, apiKey string) { + // set api key to Authorization header + if apiKey == "" { + return + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) +} + +func GetApiKey(req *bfe_basic.Request) string { + // get api key from Authorization header + authHeader := req.HttpRequest.Header.Get("Authorization") + if authHeader == "" { + return "" + } + + // remove "Bearer " prefix if exists + authHeader = strings.TrimPrefix(authHeader, "Bearer ") + authHeader = strings.TrimPrefix(authHeader, "sk-") + + return authHeader +} + +// found product handler +func (m *ModuleAITokenAuth) tokenFoundProductHandler(req *bfe_basic.Request) (int, 
*bfe_http.Response) { + matched := m.matchTokenRule(req) + if !matched { + // no rule, just pass + return bfe_module.BfeHandlerGoOn, nil + } + + // do token authentication + m.state.ReqAuth.Inc(1) + tok, err := m.ValidateUserTokenByReq(req) + if err != nil { + m.state.ReqAuthFail.Inc(1) + resp := bfe_basic.CreateSpecifiedContentResp(req, bfe_http.StatusUnauthorized, "text/plain", + fmt.Sprintf("token authentication failed: %s", err.Error())) + return bfe_module.BfeHandlerResponse, resp + } + + promptToken := GetPromptToken(req) + SetTokenAuthContext(req, &TokenAuthContext{ + Token: tok, + PromptTokens: promptToken, + CompletionTokens: -1, // -1 - unknown + }) + + return bfe_module.BfeHandlerGoOn, nil +} + +func (m *ModuleAITokenAuth) getState(params map[string][]string) ([]byte, error) { + s := m.metrics.GetAll() + return s.Format(params) +} + +func (m *ModuleAITokenAuth) getStateDiff(params map[string][]string) ([]byte, error) { + s := m.metrics.GetDiff() + return s.Format(params) +} + +func (m *ModuleAITokenAuth) monitorHandlers() map[string]interface{} { + handlers := map[string]interface{}{ + m.name: m.getState, + m.name + ".diff": m.getStateDiff, + } + return handlers +} + +func (m *ModuleAITokenAuth) reloadHandlers() map[string]interface{} { + handlers := map[string]interface{}{ + m.name: m.loadProductRuleConf, + } + return handlers +} + +func (m *ModuleAITokenAuth) Init(cbs *bfe_module.BfeCallbacks, whs *web_monitor.WebHandlers, + cr string) error { + var err error + + confPath := bfe_module.ModConfPath(cr, m.name) + if m.conf, err = ConfLoad(confPath, cr); err != nil { + return fmt.Errorf("%s: conf load err %v", m.name, err) + } + openDebug = m.conf.Log.OpenDebug + + // new Redis Client + r := m.conf.Redis + options := &redis_client.Options{ + ServiceConf: r.Bns, + MaxIdle: r.MaxIdle, + MaxActive: r.MaxActive, + Wait: false, + ConnTimeoutMs: r.ConnectTimeout, + ReadTimeoutMs: r.ReadTimeout, + WriteTimeoutMs: r.WriteTimeout, + Password: r.Password, + } + + 
client := redis_client.NewRedisClient(options) + m.redisClient = client + + if err = m.loadProductRuleConf(nil); err != nil { + return fmt.Errorf("%s: loadProductRuleConf() err %v", m.name, err) + } + + err = cbs.AddFilter(bfe_module.HandleFoundProduct, m.tokenFoundProductHandler) + if err != nil { + return fmt.Errorf("%s.Init(): AddFilter(m.tokenFoundProductHandler): %s", m.name, err.Error()) + } + + err = cbs.AddFilter(bfe_module.HandleReadResponse, m.tokenReadResponseHandler) + if err != nil { + return fmt.Errorf("%s.Init(): AddFilter(m.tokenReadResponseHandler): %v", m.name, err) + } + + err = cbs.AddFilter(bfe_module.HandleRequestFinish, m.tokenRequestFinishHandler) + if err != nil { + return fmt.Errorf("%s.Init(): AddFilter(m.tokenReadResponseHandler): %v", m.name, err) + } + + err = web_monitor.RegisterHandlers(whs, web_monitor.WebHandleMonitor, m.monitorHandlers()) + if err != nil { + return fmt.Errorf("%s.Init(): RegisterHandlers(m.monitorHandlers): %v", m.name, err) + } + + err = web_monitor.RegisterHandlers(whs, web_monitor.WebHandleReload, m.reloadHandlers()) + if err != nil { + return fmt.Errorf("%s.Init(): RegisterHandlers(m.reloadHandlerr): %v", m.name, err) + } + + return nil +} + +func usedQuotaKey(key string, updatetime int64) string { + return fmt.Sprintf("usedquota_%s:%d", key, updatetime) +} + +type TokenAuthContext struct { + Token *Token + PromptTokens int64 // number of tokens in the prompt + CompletionTokens int64 // number of tokens in the completion + UsedQuota int64 // used quota for this request +} +const REQ_TOKEN_AUTH_CONTEXT = "tokenauth_ctx" +func GetTokenAuthContext(req *bfe_basic.Request) *TokenAuthContext { + ctx := req.GetContext(REQ_TOKEN_AUTH_CONTEXT) + tokenCtx, ok := ctx.(*TokenAuthContext) + if !ok { + return nil + } + + return tokenCtx +} +// SetTokenAuthContext sets the token authentication context in the request +func SetTokenAuthContext(req *bfe_basic.Request, tokenCtx *TokenAuthContext) { + 
req.SetContext(REQ_TOKEN_AUTH_CONTEXT, tokenCtx) +} + +func GetPromptToken(req *bfe_basic.Request) int64 { + // get prompt token from request body + // just a simple implementation here, only consider content length + // just a simple estimation: 1 token ~ 4 bytes + if req.HttpRequest.ContentLength > 0 { + return req.HttpRequest.ContentLength / 4 + } + + // if content length is not set, try to peek the body + bodyAccessor, _ := req.HttpRequest.GetBodyAccessor() + if bodyAccessor == nil { + return 0 + } + + body, _ := bodyAccessor.GetBytes() + return int64(len(body)) / 4 +} diff --git a/bfe_modules/mod_ai_token_auth/token.go b/bfe_modules/mod_ai_token_auth/token.go new file mode 100644 index 000000000..2e7cbeb60 --- /dev/null +++ b/bfe_modules/mod_ai_token_auth/token.go @@ -0,0 +1,140 @@ +// Copyright (c) 2025 The BFE Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package mod_ai_token_auth + +import ( + "errors" + "fmt" + "net" + "strings" + "sync/atomic" + + "github.com/google/uuid" +) + +const ( + TokenStatusEnabled = 1 + TokenStatusDisabled = 2 + TokenStatusExpired = 3 + TokenStatusExhausted = 4 +) + +const ( + ActionCheckToken = "CHECK_TOKEN" +) + +type Token struct { + Key string + Status int + Name string + UpdateTime int64 + ExpiredTime int64 + RemainQuota int64 + UsedQuota *atomic.Uint64 + UnlimitedQuota bool + Models []string + Subnet []*net.IPNet +} + +type TokenFile struct { + Key string `json:"key"` + Status int `json:"status"` + Name string `json:"name"` + UpdateTime int64 `json:"update_time"` + ExpiredTime int64 `json:"expired_time"` // -1 means never expired + RemainQuota int64 `json:"remain_quota"` + UnlimitedQuota bool `json:"unlimited_quota"` + Models *string `json:"models"` // allowed models + Subnet *string `json:"subnet"` // allowed subnet + models []string + subnet []*net.IPNet +} + +func tokenCheck(conf *TokenFile) error { + if conf.Key == "" { + return errors.New("no Key") + } + if conf.Status < TokenStatusEnabled || conf.Status > TokenStatusExhausted { + return fmt.Errorf("invalid Status: %d", conf.Status) + } + if conf.ExpiredTime < -1 { + return fmt.Errorf("invalid ExpiredTime: %d", conf.ExpiredTime) + } + if conf.RemainQuota < 0 { + return fmt.Errorf("invalid RemainQuota: %d", conf.RemainQuota) + } + if conf.UnlimitedQuota && conf.RemainQuota != 0 { + return errors.New("if UnlimitedQuota is true, RemainQuota must be 0") + } + if conf.Models != nil { + conf.models = strings.Split(*conf.Models, ",") + for i := 0; i < len(conf.models); i++ { + conf.models[i] = strings.TrimSpace(conf.models[i]) + if conf.models[i] == "" { + return errors.New("Models cannot contain empty strings") + } + } + } + if conf.Subnet != nil { + res := strings.Split(*conf.Subnet, ",") + conf.subnet = make([]*net.IPNet, len(res)) + for i := 0; i < len(res); i++ { + res[i] = strings.TrimSpace(res[i]) + _, subnet, err := 
net.ParseCIDR(res[i]) + if err != nil { + return fmt.Errorf("invalid subnet %s: %v", res[i], err) + } + conf.subnet[i] = subnet + } + } + return nil +} + +func tokenConvert(tokenFile TokenFile) Token { + return Token{ + Key: tokenFile.Key, + Status: tokenFile.Status, + Name: tokenFile.Name, + UpdateTime: tokenFile.UpdateTime, + ExpiredTime: tokenFile.ExpiredTime, + RemainQuota: tokenFile.RemainQuota, + UnlimitedQuota: tokenFile.UnlimitedQuota, + Models: tokenFile.models, + Subnet: tokenFile.subnet, + } +} + +type ActionFile struct { + Cmd string +} + +type Action ActionFile + +func ActionFileCheck(conf *ActionFile) error { + if conf.Cmd != ActionCheckToken { + return fmt.Errorf("invalid cmd: %s", conf.Cmd) + } + return nil +} + +func actionConvert(actionFile ActionFile) Action { + return Action(actionFile) +} + +func GetUUID() string { + code := uuid.New().String() + code = strings.Replace(code, "-", "", -1) + return code +} diff --git a/bfe_modules/mod_ai_token_auth/token_rule_load.go b/bfe_modules/mod_ai_token_auth/token_rule_load.go new file mode 100644 index 000000000..9cf632530 --- /dev/null +++ b/bfe_modules/mod_ai_token_auth/token_rule_load.go @@ -0,0 +1,235 @@ +// Copyright (c) 2025 The BFE Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package mod_ai_token_auth + +import ( + "encoding/json" + "errors" + "fmt" + "os" +) + +import ( + "github.com/bfenetworks/bfe/bfe_basic/condition" +) + +type tokenRuleFile struct { + Cond *string + Action *ActionFile +} + +type tokenRule struct { + Cond condition.Condition + Action Action +} + +type tokenFileMap map[string]*TokenFile +type tokenMap map[string]*Token + +type ProductTokenFiles map[string]*tokenFileMap +type ProductTokens map[string]*tokenMap + +type tokenRuleFileList []tokenRuleFile +type tokenRuleList []tokenRule + +type ProductRulesFile map[string]*tokenRuleFileList +type ProductRules map[string]*tokenRuleList + +type productRuleConfFile struct { + Version *string + Tokens *ProductTokenFiles + Config *ProductRulesFile +} + +type productRuleConf struct { + Version string + Tokens ProductTokens + Config ProductRules +} + +func tokenMapCheck(conf *tokenFileMap) error { + if conf == nil { + return errors.New("no tokenMap") + } + + for key, token := range *conf { + if err := tokenCheck(token); err != nil { + return fmt.Errorf("token %s: %v", key, err) + } + } + + return nil +} +func productTokensCheck(conf *ProductTokenFiles) error { + for product, tokenMap := range *conf { + if err := tokenMapCheck(tokenMap); err != nil { + return fmt.Errorf("ProductTokens %s: %v", product, err) + } + } + + return nil +} + +func tokenRuleCheck(conf tokenRuleFile) error { + if conf.Cond == nil { + return errors.New("no Cond") + } + + if conf.Action == nil { + return errors.New("no Action") + } + if err := ActionFileCheck(conf.Action); err != nil { + return err + } + + return nil +} + +func tokenRuleListCheck(conf *tokenRuleFileList) error { + for index, rule := range *conf { + err := tokenRuleCheck(rule) + if err != nil { + return fmt.Errorf("tokenRule: %d, %v", index, err) + } + } + + return nil +} + +func productRulesCheck(conf *ProductRulesFile) error { + for product, ruleList := range *conf { + if ruleList == nil { + return fmt.Errorf("no tokenRuleList for 
product: %s", product) + } + + err := tokenRuleListCheck(ruleList) + if err != nil { + return fmt.Errorf("ProductRules: %s, %v", product, err) + } + } + + return nil +} + +func productRuleConfCheck(conf productRuleConfFile) error { + var err error + + if conf.Version == nil { + return errors.New("no Version") + } + + if conf.Config == nil { + return errors.New("no Config") + } + + if conf.Tokens == nil { + return errors.New("no Tokens") + } + + err = productTokensCheck(conf.Tokens) + if err != nil { + return fmt.Errorf("tokens: %v", err) + } + + err = productRulesCheck(conf.Config) + if err != nil { + return fmt.Errorf("config: %v", err) + } + + return nil +} + +func ruleConvert(ruleFile tokenRuleFile) (tokenRule, error) { + rule := tokenRule{} + + cond, err := condition.Build(*ruleFile.Cond) + if err != nil { + return rule, err + } + rule.Cond = cond + rule.Action = actionConvert(*ruleFile.Action) + return rule, nil +} + +func ruleListConvert(ruleFileList *tokenRuleFileList) (*tokenRuleList, error) { + var ruleList tokenRuleList + + for _, ruleFile := range *ruleFileList { + rule, err := ruleConvert(ruleFile) + if err != nil { + return nil, err + } + ruleList = append(ruleList, rule) + } + + return &ruleList, nil +} + +func tokenMapConvert(tokenFileMap *tokenFileMap) (*tokenMap, error) { + tokenMap := make(tokenMap) + + for key, tokenFile := range *tokenFileMap { + token := tokenConvert(*tokenFile) + tokenMap[key] = &token + } + + return &tokenMap, nil +} + +func ProductRuleConfLoad(filename string) (productRuleConf, error) { + var conf productRuleConf + var err error + + file, err := os.Open(filename) + if err != nil { + return conf, err + } + defer file.Close() + + decoder := json.NewDecoder(file) + var config productRuleConfFile + err = decoder.Decode(&config) + if err != nil { + return conf, err + } + + err = productRuleConfCheck(config) + if err != nil { + return conf, err + } + + conf.Version = *config.Version + conf.Config = make(ProductRules) + for 
product, ruleFileList := range *config.Config { + ruleList, err := ruleListConvert(ruleFileList) + if err != nil { + return conf, err + } + conf.Config[product] = ruleList + } + + conf.Tokens = make(ProductTokens) + if config.Tokens != nil { + for product, tokenMap := range *config.Tokens { + tokenMap, err := tokenMapConvert(tokenMap) + if err != nil { + return conf, err + } + conf.Tokens[product] = tokenMap + } + } + + return conf, nil +} diff --git a/bfe_modules/mod_ai_token_auth/token_rule_table.go b/bfe_modules/mod_ai_token_auth/token_rule_table.go new file mode 100644 index 000000000..c6c75d4e4 --- /dev/null +++ b/bfe_modules/mod_ai_token_auth/token_rule_table.go @@ -0,0 +1,241 @@ +// Copyright (c) 2025 The BFE Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package mod_ai_token_auth + +import ( + "errors" + "fmt" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/bfenetworks/bfe/bfe_basic" + "github.com/bfenetworks/bfe/bfe_basic/condition" +) + +type TokenRuleTable struct { + lock sync.RWMutex + version string + productRules ProductRules + productTokens ProductTokens +} + +func NewTokenRuleTable() *TokenRuleTable { + t := new(TokenRuleTable) + t.productRules = make(ProductRules) + t.productTokens = make(ProductTokens) + return t +} + +func (t *TokenRuleTable) Update(conf productRuleConf) (oldtokens []*Token) { + // check token update time, if the token is not updated, we keep the old used quota + for prod, tokenmap := range t.productTokens { + newTokenMap, ok := conf.Tokens[prod] + if !ok { + // product not in new conf, all these tokens are removed + for _, t := range *tokenmap { + oldtokens = append(oldtokens, t) + } + } else { + // product in new conf, check each token + for k, t := range *tokenmap { + newToken, ok := (*newTokenMap)[k] + if !ok { + // token not in new conf, remove it + oldtokens = append(oldtokens, t) + } else if t.UpdateTime == newToken.UpdateTime { + // token not updated, keep the old used quota + newToken.UsedQuota = t.UsedQuota + } else { + // token updated, reset used quota + newToken.UsedQuota = &atomic.Uint64{} + oldtokens = append(oldtokens, t) + } + } + } + } + // init new tokens' UsedQuota + for _, tokenmap := range conf.Tokens { + for _, t := range *tokenmap { + if t.UsedQuota == nil { + t.UsedQuota = &atomic.Uint64{} + } + } + } + + t.lock.Lock() + t.version = conf.Version + t.productRules = conf.Config + t.productTokens = conf.Tokens + t.lock.Unlock() + + return +} + +func (t *TokenRuleTable) Search(product string) (*tokenRuleList, bool) { + t.lock.RLock() + productRules := t.productRules + t.lock.RUnlock() + + rules, ok := productRules[product] + return rules, ok +} + +func (t *TokenRuleTable) GetToken(product, key string) (*Token, bool) { + t.lock.RLock() + tokenMap := 
t.productTokens[product] + t.lock.RUnlock() + + if tokenMap == nil { + return nil, false + } + tok, ok := (*tokenMap)[key] + return tok, ok +} + +func (t *TokenRuleTable) ValidateUserToken(product, key string) (token *Token, err error) { + if key == "" { + return nil, errors.New("no token") + } + var ok bool + token, ok = t.GetToken(product, key) + if !ok { + return nil, errors.New("token not found") + } + + switch token.Status { + case TokenStatusExhausted: + return nil, fmt.Errorf("token %s quota exhausted", token.Name) + case TokenStatusExpired: + return nil, fmt.Errorf("token %s expired", token.Name) + case TokenStatusDisabled: + return nil, fmt.Errorf("token %s disabled", token.Name) + } + + if token.ExpiredTime != -1 && token.ExpiredTime < time.Now().Unix() { + token.Status = TokenStatusExpired + return nil, fmt.Errorf("token %s expired", token.Name) + } + + if !token.UnlimitedQuota && token.RemainQuota <= 0 { + token.Status = TokenStatusExhausted + return nil, fmt.Errorf("token %s quota exhausted", token.Name) + } + return token, nil +} + +func (m *ModuleAITokenAuth) ValidateUserTokenByReq(req *bfe_basic.Request) (token *Token, err error) { + key := GetApiKey(req) + if key == "" { + return nil, errors.New("no token") + } + product := req.Route.Product + if product == "" { + return nil, errors.New("no product") + } + + var ok bool + token, ok = m.ruleTable.GetToken(product, key) + if !ok { + return nil, errors.New("token not found") + } + + switch token.Status { + case TokenStatusExhausted: + return nil, fmt.Errorf("token %s quota exhausted", token.Name) + case TokenStatusExpired: + return nil, fmt.Errorf("token %s expired", token.Name) + case TokenStatusDisabled: + return nil, fmt.Errorf("token %s disabled", token.Name) + } + + if token.ExpiredTime != -1 && token.ExpiredTime < time.Now().Unix() { + token.Status = TokenStatusExpired + return nil, fmt.Errorf("token %s expired", token.Name) + } + + if !token.UnlimitedQuota { + if token.RemainQuota <= 0 { + 
token.Status = TokenStatusExhausted + return nil, fmt.Errorf("token %s quota exhausted", token.Name) + } else { + used := m.GetTokenUsedQuota(token) + if used >= token.RemainQuota { + token.Status = TokenStatusExhausted + return nil, fmt.Errorf("token %s quota exhausted", token.Name) + } + } + } + + if len(token.Models) > 0 { + model, err := condition.ReqBodyJsonFetch(req, "model") + if err != nil || model == "" { + return nil, fmt.Errorf("model not found in request body: %v", err) + } + model = strings.TrimSpace(model) + inModels := false + for _, m := range token.Models { + if m == model { + inModels = true + break + } + } + if !inModels { + return nil, fmt.Errorf("model %s not allowed by token %s", model, token.Name) + } + } + + if len(token.Subnet) > 0 { + inSubnet := false + for _, subnet := range token.Subnet { + if req.ClientAddr != nil && subnet.Contains(req.ClientAddr.IP) { + inSubnet = true + break + } else if req.RemoteAddr != nil && subnet.Contains(req.RemoteAddr.IP) { + inSubnet = true + break + } + } + if !inSubnet { + return nil, fmt.Errorf("client IP not in subnet of token %s", token.Name) + } + } + return token, nil +} + +func (m *ModuleAITokenAuth) GetTokenUsedQuota(t *Token) int64 { + if t == nil { + return 0 + } + key := usedQuotaKey(t.Key, t.UpdateTime) + val, err := m.redisClient.GetInt64(key) + if err != nil { + return 0 + } + return val +} + +func (m *ModuleAITokenAuth) IncrTokenUsedQuotaBy(t *Token, delta int64) int64 { + if t == nil { + return 0 + } + key := usedQuotaKey(t.Key, t.UpdateTime) + val, err := m.redisClient.IncrBy(key, delta) + if err != nil { + return 0 + } + return val +} diff --git a/bfe_modules/mod_body_process/body_process.go b/bfe_modules/mod_body_process/body_process.go new file mode 100644 index 000000000..f8bdd76fc --- /dev/null +++ b/bfe_modules/mod_body_process/body_process.go @@ -0,0 +1,588 @@ +// Copyright (c) 2025 The BFE Authors. 
// BodyProcessor is an io.ReadCloser that decodes a body stream into events,
// runs them through a chain of processors and re-encodes the result. A
// processor can interrupt the stream by returning a *RejectionError.
type BodyProcessor struct {
	source     io.ReadCloser    // underlying body; handed to the decoder factory and closed by Close
	buffer     *bytes.Buffer    // encoded output waiting to be returned by Read
	decoder    EventDecoder     // splits source into events
	processors []EventProcessor // applied to each decoded batch in registration order
	encoder    EventEncoder     // writes processed events into buffer
	// mu         sync.Mutex
	// closed     bool
	err       error           // last error recorded while filling the buffer; io.EOF once the decoder is drained
	rejection *RejectionError // violation details stored when the stream is interrupted

	// callback invoked when the stream is interrupted
	onReject func(error, *BodyProcessor)
}
success + // events, nil - len(events) = 0, 没有更多数据, eof + // nil, error - 发生错误 + Decode() ([]Event, error) +} + +type EventDecoderFac func(source io.Reader) (EventDecoder, error) +type EventDecoderFacWithReq func(source io.Reader, req bfe_basic.Request) (EventDecoder, error) + +type EventEncoder interface { + Encode(events []Event) (int, error) +} + +type EventEncoderFac func(dest io.Writer) (EventEncoder, error) + +type EventProcessor interface { + Process(events []Event) ([]Event, error) +} + +func (bp *BodyProcessor) CreateEventDecoder(fac EventDecoderFac) { + // bp.mu.Lock() + // defer bp.mu.Unlock() + dec, err := fac(bp.source) + if err != nil { + bp.err = fmt.Errorf("create event decoder: %w", err) + return + } + bp.decoder = dec +} + +func (bp *BodyProcessor) CreateEventEncoder(fac EventEncoderFac) { + // bp.mu.Lock() + // defer bp.mu.Unlock() + enc, err := fac(bp.buffer) + if err != nil { + bp.err = fmt.Errorf("create event encoder: %w", err) + return + } + bp.encoder = enc +} + +func (bp *BodyProcessor) AddProcessor(p EventProcessor) { + // bp.mu.Lock() + // defer bp.mu.Unlock() + + bp.processors = append(bp.processors, p) +} + +// ProcessorFunc 简化处理器实现 +type EventProcessorFunc func([]Event) ([]Event, error) + +func (f EventProcessorFunc) Process(events []Event) ([]Event, error) { + return f(events) +} + +// Read 实现io.Reader接口(支持中断) +func (bp *BodyProcessor) Read(p []byte) (n int, err error) { + // bp.mu.Lock() + // defer bp.mu.Unlock() + + // if bp.rejection != nil { + // return 0, bp.rejection // 返回违规错误 + // } + + if bp.err != nil && bp.err != io.EOF { + return 0, bp.err + } + + // 检查缓冲区是否足够 + if bp.buffer.Len() < len(p) && bp.err != io.EOF { + if err := bp.fillBuffer(); err != nil { + return 0, err + } + } + + return bp.buffer.Read(p) +} + +// fillBuffer 实现内容审查和中断 +func (bp *BodyProcessor) fillBuffer() error { + for { + events, decodeErr := bp.decoder.Decode() + if decodeErr != nil { + bp.err = decodeErr + return decodeErr + } + if len(events) == 0 { + 
bp.err = io.EOF + // eof is not an error for fillbuffer, just break the loop + break + } + // 处理事件 + for _, processor := range bp.processors { + if len(events) == 0 { + break // 没有事件可处理 + } + var processErr error + events, processErr = processor.Process(events) + if processErr != nil { + bp.err = processErr + // 检查是否为中断错误 + if cvErr, ok := processErr.(*RejectionError); ok { + bp.handleRejection(cvErr) + return cvErr + } + return processErr + } + } + // 编码事件 + n, encodeErr := bp.encoder.Encode(events) + if encodeErr != nil { + bp.err = encodeErr + return encodeErr + } + if n > 0 { + break // 至少有一个事件被处理 + } + } + return nil +} + +// handleRejection 处理内容违规事件 +func (bp *BodyProcessor) handleRejection(err *RejectionError) { + bp.rejection = err + + // 触发回调 + if bp.onReject != nil { + bp.onReject(err, bp) + } +} + +// RejectionResponse 获取中断响应(如在反向代理中使用) +func (bp *BodyProcessor) RejectionResponse() *RejectionError { + // bp.mu.Lock() + // defer bp.mu.Unlock() + return bp.rejection +} + +// Close 实现io.Closer接口 +func (bp *BodyProcessor) Close() error { + // bp.mu.Lock() + // defer bp.mu.Unlock() + + bp.buffer.Reset() + return bp.source.Close() +} + +// FillBuffer 公开的缓冲区填充方法 +// 安全地从源读取并处理一个数据块 +func (bp *BodyProcessor) FillBuffer() error { + // bp.mu.Lock() + // defer bp.mu.Unlock() + + if bp.err != nil { + return bp.err + } + + return bp.fillBuffer() +} + +func (m *ModuleBodyProcess) DoRequestProcess(req *bfe_basic.Request, conf *BodyProcessConfig) *BodyProcessor { + if conf == nil { + return nil // 没有配置,直接返回 + } + + m.state.ReqProcess.Inc(1) + + bp := NewBodyProcessor(req.HttpRequest.Body) + switch conf.Dec { + // case "sse": // sse is not available for request body + // bp.CreateEventDecoder(NewSSEEventDecoder) + case "line": + bp.CreateEventDecoder(NewLineDecoder) + case "json": + bp.CreateEventDecoder(NewJsonDecoder) + default: + contentType := req.HttpRequest.Header.Get("Content-Type") + bp.CreateEventDecoder(func(source io.Reader)(EventDecoder, error) { + return 
NewContentTypeDecoder(source, contentType)} ) // 使用ContentTypeDecoder根据Content-Type自动选择解码器 + // bp.CreateEventDecoder(NewJsonDecoder) // 默认使用ndJson解码 + } + bp.CreateEventEncoder(NewGeneralEncoder) + for _, proc := range conf.Proc { + switch proc.Name { + case "textfilter": + caf, _ := NewContentAudit(proc.Params[0], false) + bp.AddProcessor(caf) + } + } + + req.HttpRequest.Body = bp + req.HttpRequest.ContentLength = -1 // 设置为-1表示不确定长度 + req.HttpRequest.Header.Del("Content-Length") + return bp +} + +func (m *ModuleBodyProcess) DoResponseProcess(req *bfe_basic.Request, res *bfe_http.Response, conf *BodyProcessConfig) *BodyProcessor { + // 检查是否需要处理streamcompletion + ccq := NewCalcCompletionQuota(req) + + if conf == nil && ccq == nil { + return nil // 没有配置,直接返回 + } + + m.state.ResProcess.Inc(1) + + bp := NewBodyProcessor(res.Body) + // 缺省添加streamcompletion处理器 + if ccq != nil { + bp.AddProcessor(ccq) + } + + var dec string + if conf != nil { + dec = conf.Dec + } + + switch dec { + case "sse": // sse is not available for request body + bp.CreateEventDecoder(NewSSEEventDecoder) + case "line": + bp.CreateEventDecoder(NewLineDecoder) + case "json": + bp.CreateEventDecoder(NewJsonDecoder) + default: + contentType := res.Header.Get("Content-Type") + bp.CreateEventDecoder(func(source io.Reader)(EventDecoder, error) { + return NewContentTypeDecoder(source, contentType)} ) // 使用ContentTypeDecoder根据Content-Type自动选择解码器 + // bp.CreateEventDecoder(NewJsonDecoder) // 默认使用ndJson解码 + } + + bp.CreateEventEncoder(NewGeneralEncoder) + + if conf != nil { + for _, proc := range conf.Proc { + switch proc.Name { + case "textfilter": + caf, _ := NewContentAudit(proc.Params[0], true) + bp.AddProcessor(caf) + } + } + } + + res.Body = bp + res.ContentLength = -1 // 设置为-1表示不确定长度 + res.Header.Del("Content-Length") + return bp +} +/* +func (m *ModuleBodyProcess) DoResponseProcess(req *bfe_basic.Request, res *bfe_http.Response, conf *BodyProcessConfig) *BodyProcessor { + if conf == nil { + return nil 
// SSEEvent is a single server-sent event.
type SSEEvent struct {
	ID    string
	Event string
	Data  []byte
	Retry int
	// raw   []byte // original event bytes
	truncated bool // set when the event was cut off before its terminating blank line
}

// ToBytes serializes the event in SSE wire format.
//
// Fields are written in the order id, event, data, retry; empty ID/Event,
// empty Data and non-positive Retry are omitted. Data is emitted as one
// "data:" line per "\n"-separated segment. The terminating blank line is
// written unless the event is marked as truncated.
func (e *SSEEvent) ToBytes() []byte {
	var out bytes.Buffer

	if id := e.ID; id != "" {
		fmt.Fprintf(&out, "id: %s\n", id)
	}
	if name := e.Event; name != "" {
		fmt.Fprintf(&out, "event: %s\n", name)
	}
	if len(e.Data) > 0 {
		for _, segment := range bytes.Split(e.Data, []byte("\n")) {
			out.WriteString("data: ")
			out.Write(segment)
			out.WriteByte('\n')
		}
	}
	if e.Retry > 0 {
		fmt.Fprintf(&out, "retry: %d\n", e.Retry)
	}
	if !e.truncated {
		out.WriteByte('\n')
	}

	return out.Bytes()
}
event.ToBytes() + n, err := enc.dest.Write(data) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +type SSEEventDecoder struct { + scanner *bufio.Scanner +} + +func NewSSEEventDecoder(source io.Reader) (EventDecoder, error) { + scanner := bufio.NewScanner(source) + return &SSEEventDecoder{scanner: scanner}, nil +} + +func (dec *SSEEventDecoder) Decode() ([]Event, error) { + var current SSEEvent + dataLines := []string{} + for dec.scanner.Scan() { + line := dec.scanner.Text() + if line == "" { + // 空行表示一个完整的事件结束 + if len(dataLines) == 0 && current.Event == "" && current.ID == "" && len(current.Data) == 0 { + continue // 跳过空事件 + } + current.Data = []byte(strings.Join(dataLines, "\n")) + return []Event{¤t}, nil + } + + // 解析SSE事件 + if strings.HasPrefix(line, "event:") { + current.Event = strings.TrimSpace(line[6:]) + } else if strings.HasPrefix(line, "data:") { + dataLines = append(dataLines, strings.TrimSpace(line[5:])) + } else if strings.HasPrefix(line, "id:") { + current.ID = strings.TrimSpace(line[3:]) + } else if strings.HasPrefix(line, "retry:") { + var retry int + _, err := fmt.Sscanf(line[6:], "%d", &retry) + if err != nil { + return nil, fmt.Errorf("invalid retry value: %s", line[6:]) + } + current.Retry = retry + } else { + // 未知的SSE行,可能需要处理或忽略 + return nil, fmt.Errorf("unknown SSE line: %s", line) + } + } + // 检查是否有未完成的事件 + if len(dataLines) > 0 || current.Event != "" || current.ID != "" { + current.Data = []byte(strings.Join(dataLines, "\n")) + current.truncated = true // 标记为被截断 + return []Event{¤t}, nil + } + + return []Event{}, dec.scanner.Err() +} + +type RawEvent []byte + +func (e *RawEvent) ToBytes() []byte { + return *e +} + +type LineDecoder struct { + reader *bufio.Reader +} + +func NewLineDecoder(source io.Reader) (EventDecoder, error) { + reader := bufio.NewReader(source) + return &LineDecoder{reader: reader}, nil +} + +func (dec *LineDecoder) Decode() ([]Event, error) { + line, err := 
dec.reader.ReadBytes('\n') + if len(line) != 0 { + re := RawEvent(line) + return []Event{&re}, nil + } + if err == io.EOF { + return []Event{}, nil // 没有更多数据 + } + return nil, fmt.Errorf("line decode error: %w", err) +} + +type JsonDecoder struct { + dec *json.Decoder +} + +func NewJsonDecoder(source io.Reader) (EventDecoder, error) { + dec := json.NewDecoder(source) + return &JsonDecoder{dec: dec}, nil +} + +func (dec *JsonDecoder) Decode() ([]Event, error) { + var event json.RawMessage + // 尝试解码一个JSON对象 + if err := dec.dec.Decode(&event); err != nil { + if err == io.EOF { + return []Event{}, nil // 没有更多数据 + } + return nil, fmt.Errorf("json decode error: %w", err) + } + re := RawEvent(event) + return []Event{&re}, nil +} + +type ContentTypeDecoder struct { + contentType string + dec EventDecoder +} + +func NewContentTypeDecoder(source io.Reader, contentType string) (EventDecoder, error) { + var dec EventDecoder + switch contentType { + case "application/sse", "text/event-stream", "application/x-sse": + dec, _ = NewSSEEventDecoder(source) + case "application/json", "application/ndjson", "application/x-ndjson": + dec, _ = NewJsonDecoder(source) // ndjson is a line-delimited JSON, can use JsonDecoder + default: + dec, _ = NewLineDecoder(source) // 默认使用行解码器 + } + + return &ContentTypeDecoder{contentType: contentType, dec: dec}, nil +} + +func (ctd *ContentTypeDecoder) Decode() ([]Event, error) { + return ctd.dec.Decode() +} + +func GetEventTokens(ev Event) int64 { + if ev == nil { + return 0 + } + + switch e := ev.(type) { + case *RawEvent: + return int64(len(*e)/4) + case *SSEEvent: + return int64(len(e.Data)/4) + default: + return 0 + } +} + +func NewCalcCompletionQuota(req *bfe_basic.Request) EventProcessorFunc { + ctx := mod_ai_token_auth.GetTokenAuthContext(req) + if ctx == nil || ctx.CompletionTokens != -1 { + return nil // 没有token上下文,或 CompletionTokens 已知,无需计算 + } + return func(events []Event) ([]Event, error) { + for _, ev := range events { + if 
ctx.CompletionTokens == -1 { + ctx.CompletionTokens = 0 // 初始化为0 + } + // 累加事件的token数 + ctx.CompletionTokens += GetEventTokens(ev) + } + return events, nil // 没有事件,直接返回 + } +} diff --git a/bfe_modules/mod_body_process/body_process_rule_load.go b/bfe_modules/mod_body_process/body_process_rule_load.go new file mode 100644 index 000000000..4c3f9c574 --- /dev/null +++ b/bfe_modules/mod_body_process/body_process_rule_load.go @@ -0,0 +1,195 @@ +// Copyright (c) 2025 The BFE Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
// ProcConf names one processing step and carries its parameters.
type ProcConf struct {
	Name   string
	Params []string
}

// BodyProcessConfig describes how a body is decoded, processed and encoded.
type BodyProcessConfig struct {
	Dec  string
	Enc  string
	Proc []ProcConf // processing steps
}

// BodyProcessConfigCheck validates a body-process configuration.
//
// A nil config is valid (a rule may define only a request or only a response
// configuration). A "textfilter" step must carry at least one parameter,
// because Params[0] is read unconditionally when the processor is built.
// Steps with other names are not inspected.
func BodyProcessConfigCheck(config *BodyProcessConfig) error {
	if config == nil {
		return nil
	}

	for index, proc := range config.Proc {
		if proc.Name == "textfilter" && len(proc.Params) == 0 {
			return fmt.Errorf("proc %d: textfilter requires a parameter in Params[0]", index)
		}
	}

	return nil
}
+} + +func productRuleConfCheck(conf productRuleConfFile) error { + var err error + + if conf.Version == nil { + return errors.New("no Version") + } + + if conf.Config == nil { + return errors.New("no Config") + } + + err = productRulesCheck(conf.Config) + if err != nil { + return fmt.Errorf("config: %v", err) + } + + return nil +} + +func ruleConvert(ruleFile processRuleFile) (processRule, error) { + rule := processRule{} + + cond, err := condition.Build(*ruleFile.Cond) + if err != nil { + return rule, err + } + rule.Cond = cond + rule.RequestProcess = ruleFile.RequestProcess + rule.ResponseProcess = ruleFile.ResponseProcess + return rule, nil +} + +func ruleListConvert(ruleFileList processRuleFileList) (processRuleList, error) { + var ruleList processRuleList + + for _, ruleFile := range ruleFileList { + rule, err := ruleConvert(ruleFile) + if err != nil { + return nil, err + } + ruleList = append(ruleList, rule) + } + + return ruleList, nil +} + +func ProductRuleConfLoad(filename string) (productRuleConf, error) { + var conf productRuleConf + var err error + + file, err := os.Open(filename) + if err != nil { + return conf, err + } + defer file.Close() + + decoder := json.NewDecoder(file) + var config productRuleConfFile + err = decoder.Decode(&config) + if err != nil { + return conf, err + } + + err = productRuleConfCheck(config) + if err != nil { + return conf, err + } + + conf.Version = *config.Version + conf.Config = make(ProductRules) + for product, ruleFileList := range *config.Config { + ruleList, err := ruleListConvert(ruleFileList) + if err != nil { + return conf, err + } + conf.Config[product] = ruleList + } + + return conf, nil +} diff --git a/bfe_modules/mod_body_process/body_process_rule_table.go b/bfe_modules/mod_body_process/body_process_rule_table.go new file mode 100644 index 000000000..dfa0e0675 --- /dev/null +++ b/bfe_modules/mod_body_process/body_process_rule_table.go @@ -0,0 +1,48 @@ +// Copyright (c) 2025 The BFE Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mod_body_process + +import ( + "sync" +) + +type ProcessRuleTable struct { + lock sync.RWMutex + version string + productRules ProductRules +} + +func NewTokenRuleTable() *ProcessRuleTable { + t := new(ProcessRuleTable) + t.productRules = make(ProductRules) + return t +} + +func (t *ProcessRuleTable) Update(conf productRuleConf) { + t.lock.Lock() + t.version = conf.Version + t.productRules = conf.Config + t.lock.Unlock() +} + +func (t *ProcessRuleTable) Search(product string) (processRuleList, bool) { + t.lock.RLock() + productRules := t.productRules + t.lock.RUnlock() + + rules, ok := productRules[product] + return rules, ok +} + diff --git a/bfe_modules/mod_body_process/conf_mod_body_process.go b/bfe_modules/mod_body_process/conf_mod_body_process.go new file mode 100644 index 000000000..a58a4de82 --- /dev/null +++ b/bfe_modules/mod_body_process/conf_mod_body_process.go @@ -0,0 +1,62 @@ +// Copyright (c) 2025 The BFE Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package mod_body_process + +import ( + "github.com/baidu/go-lib/log" + gcfg "gopkg.in/gcfg.v1" +) + +import ( + "github.com/bfenetworks/bfe/bfe_util" +) + +type ConfModBodyProcess struct { + Basic struct { + ProductRulePath string + } + + Log struct { + OpenDebug bool + } +} + +func (cfg *ConfModBodyProcess) Check(confRoot string) error { + if cfg.Basic.ProductRulePath == "" { + log.Logger.Warn("ModBodyProcess.ProductRulePath not set, use default value") + cfg.Basic.ProductRulePath = "mod_body_process/body_process.data" + } + + cfg.Basic.ProductRulePath = bfe_util.ConfPathProc(cfg.Basic.ProductRulePath, confRoot) + + return nil +} + +func ConfLoad(filePath string, confRoot string) (*ConfModBodyProcess, error) { + var cfg ConfModBodyProcess + var err error + + err = gcfg.ReadFileInto(&cfg, filePath) + if err != nil { + return &cfg, err + } + + err = cfg.Check(confRoot) + if err != nil { + return &cfg, err + } + + return &cfg, nil +} diff --git a/bfe_modules/mod_body_process/content_audit_process.go b/bfe_modules/mod_body_process/content_audit_process.go new file mode 100644 index 000000000..251684d79 --- /dev/null +++ b/bfe_modules/mod_body_process/content_audit_process.go @@ -0,0 +1,173 @@ +// Copyright (c) 2025 The BFE Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package mod_body_process + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + "github.com/baidu/go-lib/log" +) + +var ( + httpClient *http.Client + mutex sync.RWMutex +) + +func GetHTTPClient() *http.Client { + mutex.RLock() + defer mutex.RUnlock() + if httpClient == nil { + return &http.Client{} + } + return httpClient +} + +func SetHTTPClient(client *http.Client) { + mutex.Lock() + defer mutex.Unlock() + if client == nil { + httpClient = nil + } else { + httpClient = client + } +} + +func init() { + // initialize the HTTP client with a timeout + SetHTTPClient(&http.Client{ + Timeout: 10 * time.Second, + }) +} + +type ContentAudit struct { + url string + replace bool // true: replace text; false: filter text +} + +func NewContentAudit(urlStr string, replace bool) (*ContentAudit, error) { + if replace { + urlStr = strings.TrimSuffix(urlStr, "/") + "/text-replace" + } else { + urlStr = strings.TrimSuffix(urlStr, "/") + "/text-filter" + } + + return &ContentAudit{url: urlStr, replace: replace}, nil +} + +func GetAuditData(ev Event) ([]byte, error) { + if ev == nil { + return nil, fmt.Errorf("event is nil") + } + + switch e := ev.(type) { + case *RawEvent: + return *e, nil + case *SSEEvent: + return e.Data, nil + default: + return nil, fmt.Errorf("unsupported event type: %T", ev) + } +} + +func SetAuditData(ev Event, data []byte) error { + if ev == nil { + return fmt.Errorf("event is nil") + } + + switch e := ev.(type) { + case *RawEvent: + *e = data + case *SSEEvent: + e.Data = data + default: + return fmt.Errorf("unsupported event type: %T", ev) + } + return nil +} + +func (caf *ContentAudit) Process(evs []Event) ([]Event, error) { + // Return the processed event list and possible errors + client := GetHTTPClient() + for _, ev := range evs { + data, err := GetAuditData(ev) + if err != nil { + log.Logger.Error("failed to get audit data: %v", err) + continue // fail to get data, skip current event + } + resp, err := 
client.PostForm(caf.url, url.Values{ "txt": {string(data)} }) + if err != nil { + log.Logger.Error("failed to audit content: %v", err) + continue // request failed, skip current event + } + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Logger.Error("failed to read response body: %v", err) + continue // fail to read response, skip current event + } + resp.Body.Close() + var result TextFilterResult + err = json.Unmarshal(body, &result) + if err != nil { + log.Logger.Error("failed to unmarshal response: %v", err) + continue // fail to parse response, skip current event + } + if caf.replace { + if result.ResultText != "" { + err = SetAuditData(ev, []byte(result.ResultText)) + if err != nil { + log.Logger.Error("failed to set audit data: %v", err) + continue // fail to set data, skip current event + } + } + } else { + if result.RiskLevel == "REJECT" || (result.RiskLevel == "REVIEW" && result.SentimentScore < -0.5) { + return nil, fmt.Errorf("content audit failed: %v", result) + } + } + } + return evs, nil +} + +type TextFilterResult struct { + Code int32 + Message string + RequestId string + RiskLevel string + RiskCode string + SentimentScore float32 + ResultText string + + Details []TextFilterDetailItem + Contacts []TextFilterContactItem +} + +type TextFilterDetailItem struct{ + RiskLevel string + RiskCode string + Position string + Text string +} + +type TextFilterContactItem struct { + ContactType string + ContactString string + Position string +} diff --git a/bfe_modules/mod_body_process/mod_body_process.go b/bfe_modules/mod_body_process/mod_body_process.go new file mode 100644 index 000000000..245031584 --- /dev/null +++ b/bfe_modules/mod_body_process/mod_body_process.go @@ -0,0 +1,205 @@ +// Copyright (c) 2025 The BFE Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mod_body_process + +import ( + "fmt" + "net/url" + + "github.com/baidu/go-lib/log" + "github.com/baidu/go-lib/web-monitor/metrics" + "github.com/baidu/go-lib/web-monitor/web_monitor" + + "github.com/bfenetworks/bfe/bfe_basic" + "github.com/bfenetworks/bfe/bfe_http" + "github.com/bfenetworks/bfe/bfe_module" +) + +const ( + ModBodyProcess = "mod_body_process" + BodyProcessResponseConfigKey = "mod_body_process.response_config" +) + +var ( + openDebug = false +) + +type ModuleBodyProcessState struct { + ReqTotal *metrics.Counter + ReqProcess *metrics.Counter + ResProcess *metrics.Counter +} + +type ModuleBodyProcess struct { + name string + conf *ConfModBodyProcess + ruleTable *ProcessRuleTable + state ModuleBodyProcessState + metrics metrics.Metrics +} + +func NewModuleBodyProcess() *ModuleBodyProcess { + m := new(ModuleBodyProcess) + m.name = ModBodyProcess + m.metrics.Init(&m.state, ModBodyProcess, 0) + m.ruleTable = NewTokenRuleTable() + return m +} + +func (m *ModuleBodyProcess) Name() string { + return m.name +} + +func (m *ModuleBodyProcess) loadProductRuleConf(query url.Values) error { + path := query.Get("path") + if path == "" { + path = m.conf.Basic.ProductRulePath + } + + conf, err := ProductRuleConfLoad(path) + if err != nil { + return fmt.Errorf("err in ProductRuleConfLoad(%s): %s", path, err) + } + + m.ruleTable.Update(conf) + return nil +} + +func (m *ModuleBodyProcess) matchProcessRule(req *bfe_basic.Request) *processRule { + if openDebug { + log.Logger.Debug("%s check request", m.name) + } + m.state.ReqTotal.Inc(1) + + 
rules, ok := m.ruleTable.Search(req.Route.Product) + if !ok { + if openDebug { + log.Logger.Debug("%s product %s not found, just pass", m.name, req.Route.Product) + } + return nil + } + + for _, rule := range rules { + if openDebug { + log.Logger.Debug("%s process rule: %v", m.name, rule) + } + + if rule.Cond.Match(req) { + return &rule + } + } + + return nil +} + +// found product handler +func (m *ModuleBodyProcess) afterLocationHandler(req *bfe_basic.Request) (int, *bfe_http.Response) { + matchedRule := m.matchProcessRule(req) + if matchedRule == nil { + // no rule, just pass + return bfe_module.BfeHandlerGoOn, nil + } + + // add body processor + if openDebug { + log.Logger.Debug("%s found matched rule: %v", m.name, matchedRule) + } + + m.DoRequestProcess(req, matchedRule.RequestProcess) + + if matchedRule.ResponseProcess != nil { + req.SetContext(BodyProcessResponseConfigKey, matchedRule.ResponseProcess) + } + return bfe_module.BfeHandlerGoOn, nil +} + +func (m *ModuleBodyProcess) readResponseHandler(req *bfe_basic.Request, res *bfe_http.Response) int { + var conf *BodyProcessConfig + // get response config from request context + data := req.GetContext(BodyProcessResponseConfigKey) + if data != nil { + var ok bool + conf, ok = data.(*BodyProcessConfig) + if !ok { + log.Logger.Warn("%s: type assertion fail, %v", m.name, data) + } + } + + m.DoResponseProcess(req, res, conf) + + return bfe_module.BfeHandlerGoOn +} + +func (m *ModuleBodyProcess) getState(params map[string][]string) ([]byte, error) { + s := m.metrics.GetAll() + return s.Format(params) +} + +func (m *ModuleBodyProcess) getStateDiff(params map[string][]string) ([]byte, error) { + s := m.metrics.GetDiff() + return s.Format(params) +} + +func (m *ModuleBodyProcess) monitorHandlers() map[string]interface{} { + handlers := map[string]interface{}{ + m.name: m.getState, + m.name + ".diff": m.getStateDiff, + } + return handlers +} + +func (m *ModuleBodyProcess) reloadHandlers() map[string]interface{} { + 
handlers := map[string]interface{}{ + m.name: m.loadProductRuleConf, + } + return handlers +} + +func (m *ModuleBodyProcess) Init(cbs *bfe_module.BfeCallbacks, whs *web_monitor.WebHandlers, + cr string) error { + var err error + + confPath := bfe_module.ModConfPath(cr, m.name) + if m.conf, err = ConfLoad(confPath, cr); err != nil { + return fmt.Errorf("%s: conf load err %v", m.name, err) + } + openDebug = m.conf.Log.OpenDebug + + if err = m.loadProductRuleConf(nil); err != nil { + return fmt.Errorf("%s: loadProductRuleConf() err %v", m.name, err) + } + + err = cbs.AddFilter(bfe_module.HandleAfterLocation, m.afterLocationHandler) + if err != nil { + return fmt.Errorf("%s.Init(): AddFilter(m.foundProductHandler): %s", m.name, err.Error()) + } + + err = cbs.AddFilter(bfe_module.HandleReadResponse, m.readResponseHandler) + if err != nil { + return fmt.Errorf("%s.Init(): AddFilter(m.readResponseHandler): %v", m.name, err) + } + + err = web_monitor.RegisterHandlers(whs, web_monitor.WebHandleMonitor, m.monitorHandlers()) + if err != nil { + return fmt.Errorf("%s.Init(): RegisterHandlers(m.monitorHandlers): %v", m.name, err) + } + + err = web_monitor.RegisterHandlers(whs, web_monitor.WebHandleReload, m.reloadHandlers()) + if err != nil { + return fmt.Errorf("%s.Init(): RegisterHandlers(m.reloadHandlerr): %v", m.name, err) + } + + return nil +} diff --git a/bfe_route/bfe_cluster/bfe_cluster.go b/bfe_route/bfe_cluster/bfe_cluster.go index 5be369b3d..9220d3df4 100644 --- a/bfe_route/bfe_cluster/bfe_cluster.go +++ b/bfe_route/bfe_cluster/bfe_cluster.go @@ -32,6 +32,7 @@ type BfeCluster struct { CheckConf *cluster_conf.BackendCheck // how to check backend GslbBasic *cluster_conf.GslbBasicConf // gslb basic httpsConf *cluster_conf.BackendHTTPS // https basic + AIConf *cluster_conf.AIConf // ai conf for cluster timeoutReadClient time.Duration // timeout for read client body timeoutReadClientAgain time.Duration // timeout for read client again @@ -58,6 +59,8 @@ func (cluster 
*BfeCluster) BasicInit(clusterConf cluster_conf.ClusterConf) { // set gslb retry conf cluster.GslbBasic = clusterConf.GslbBasic + cluster.AIConf = clusterConf.AIConf + cluster.timeoutReadClient = time.Duration(*clusterConf.ClusterBasic.TimeoutReadClient) * time.Millisecond cluster.timeoutReadClientAgain = diff --git a/bfe_server/proxy_state.go b/bfe_server/proxy_state.go index d9e299d77..585f8631c 100644 --- a/bfe_server/proxy_state.go +++ b/bfe_server/proxy_state.go @@ -43,7 +43,8 @@ type ProxyState struct { ErrBkFindLocation *metrics.Counter ErrBkNoBalance *metrics.Counter ErrBkNoCluster *metrics.Counter - + ErrBkBodyProcess *metrics.Counter + // backend side errors ErrBkConnectBackend *metrics.Counter ErrBkRequestBackend *metrics.Counter diff --git a/bfe_server/reverseproxy.go b/bfe_server/reverseproxy.go index ace59b86c..70151b825 100644 --- a/bfe_server/reverseproxy.go +++ b/bfe_server/reverseproxy.go @@ -22,6 +22,7 @@ package bfe_server import ( "crypto/tls" + "fmt" "io" "net" "reflect" @@ -37,12 +38,14 @@ import ( bfe_cluster_backend "github.com/bfenetworks/bfe/bfe_balance/backend" "github.com/bfenetworks/bfe/bfe_balance/bal_gslb" "github.com/bfenetworks/bfe/bfe_basic" + "github.com/bfenetworks/bfe/bfe_basic/condition" "github.com/bfenetworks/bfe/bfe_config/bfe_cluster_conf/cluster_conf" "github.com/bfenetworks/bfe/bfe_debug" "github.com/bfenetworks/bfe/bfe_fcgi" "github.com/bfenetworks/bfe/bfe_http" "github.com/bfenetworks/bfe/bfe_http2" "github.com/bfenetworks/bfe/bfe_module" + "github.com/bfenetworks/bfe/bfe_modules/mod_ai_token_auth" "github.com/bfenetworks/bfe/bfe_route" "github.com/bfenetworks/bfe/bfe_route/bfe_cluster" "github.com/bfenetworks/bfe/bfe_spdy" @@ -631,6 +634,8 @@ func (p *ReverseProxy) ServeHTTP(rw bfe_http.ResponseWriter, basicReq *bfe_basic var outreq *bfe_http.Request var serverConf *bfe_route.ServerDataConf var writeTimer *time.Timer + var bf BufferFiller + var ok bool req := basicReq.HttpRequest isRedirect := false @@ -792,6 +797,46 
// BufferFiller is implemented by request bodies that must be processed before
// the request is forwarded to the backend. ServeHTTP calls FillBuffer
// repeatedly until it returns a non-nil error: io.EOF means the body has been
// fully processed, any other error ends the request with a 400 response.
type BufferFiller interface {
	FillBuffer() error
}
// Client is the set of redis operations offered to callers; NewRedisClient
// returns a *RedisClient behind this interface.
type Client interface {
	Setex(key string, value []byte, expire int) error
	Get(key string) (interface{}, error)
	Expire(key string, expire int) error
	Incr(key string) (int64, error)
	IncrAndExpire(key string, expire int) (int64, error)
	Decr(key string) (int64, error)
	PIncr([]string) ([]int64, error)
	GetInt64(key string) (int64, error)
	IncrBy(key string, delta int64) (int64, error)
}

// counter names for module_state2; the redis client increments them by one
// per event when a state has been attached with SetModuleState2
var (
	RedisConn        = "REDIS_CONN"
	RedisConnFail    = "REDIS_CONN_FAIL"
	RedisAuthFail    = "REDIS_AUTH_FAIL"
	RedisExpire      = "REDIS_EXPIRE"
	RedisExpireFail  = "REDIS_EXPIRE_FAIL"
	RedisSetex       = "REDIS_SETEX"
	RedisSetexFail   = "REDIS_SETEX_FAIL"
	RedisGet         = "REDIS_GET"
	RedisGetFail     = "REDIS_GET_FAIL"
	RedisGetMiss     = "REDIS_GET_MISS"
	RedisGetHit      = "REDIS_GET_HIT"
	RedisIncr        = "REDIS_INCR"
	RedisIncrFail    = "REDIS_INCR_FAIL"
	RedisDecr        = "REDIS_DECR"
	RedisDecrFail    = "REDIS_DECR_FAIL"
	RedisSendFail    = "REDIS_SEND_FAIL"
	RedisFlushFail   = "REDIS_FLUSH_FAIL"
)

// Options carries everything needed to build a redis client.
type Options struct {
	// ServiceConf: string, bns name or a batch of bns name with weight of redis server
	ServiceConf string
	// clusterList: parsed form of ServiceConf, filled in by Format
	clusterList []RedisClusterConf
	// MaxIdle: int, max idle connections in connection pool
	MaxIdle int
	// MaxActive: int, max active connections in connection pool
	MaxActive int
	// Wait: bool, if Wait is true and pool at the MaxActive limit,
	// command waits for a connection return to the pool
	Wait bool
	// ConnTimeoutMs: int, connect redis server timeout, in ms
	ConnTimeoutMs int
	// ReadTimeoutMs: int, read redis server timeout, in ms
	ReadTimeoutMs int
	// WriteTimeoutMs: int, write redis server timeout, in ms
	WriteTimeoutMs int
	// Password: string, sent with AUTH after connecting; empty means no AUTH
	Password string
}
wait: bool, if wait is true and pool at the maxActive limit, + // command waits for a connection return to the pool + Wait bool + // ConnTimeoutMs: int, connect redis server timeout, in ms + ConnTimeoutMs int + // ReadTimeoutMs: int, read redis server timeout, in ms + ReadTimeoutMs int + // writeTimeoutMs: int, write redis server timeout, in ms + WriteTimeoutMs int + Password string +} + +func NewRedisClient(options *Options) Client { + return NewRedisBnsClient(options) +} + +func CheckRedisConf(redisServersStr string) error{ + _, err := ParseRedisBnsConf(redisServersStr) + if err != nil { + return fmt.Errorf("proxy mode, Redis.Bns check err: %s", err.Error()) + } + + return nil +} diff --git a/bfe_util/redis_client/redis_bns.go b/bfe_util/redis_client/redis_bns.go new file mode 100644 index 000000000..0e4ab1e42 --- /dev/null +++ b/bfe_util/redis_client/redis_bns.go @@ -0,0 +1,798 @@ +// Copyright (c) 2025 The BFE Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
var (
	// default bns update interval 60s
	DfBnsUpdateInterval = 60 * time.Second
	// max value for one redis cluster weight
	MaxWeightValue = 100
	// min value for one redis cluster weight
	MinWeightValue = 1
	// max value for all redis clusters weight sum
	MaxWeightSum = 10000

	// counter names for bns resolution events
	RedisGetBnsInstanceErr  = "REDIS_GET_BNS_INSTANCE_ERR"
	RedisNoBnsInstance      = "REDIS_NO_BNS_INSTANCE"
	RedisBnsInstanceChanged = "REDIS_BNS_INSTANCE_CHANGED"
)

// RedisClient spreads keys over one or more redis clusters (each named by a
// bns) according to their weights and keeps one connection pool per cluster.
type RedisClient struct {
	ConnectTimeout time.Duration // connect timeout
	ReadTimeout    time.Duration // read timeout
	WriteTimeout   time.Duration // write timeout

	Password string // password, ignore if no password

	MaxIdle   int  // max idle connections in pool
	MaxActive int  // max active connections in pool
	Wait      bool // if pool meet MaxActive limit, and Wait is true, wait for a connection return to pool

	redisClusters        []redisCluster // redisCluster list, offset is redisClusterId
	redisClusterSlotSize uint64         // hash slot size, the value is the sum of each bns's weight
	redisClusterSlotMap  []int          // offset: hash slot, value: redisClusterId

	// stateDelegate StateDelegate // state delegate, this can be nil
	moduleState2 *module_state2.State       // state in format to module_state2, may be nil
	delay        *delay_counter.DelayRecent // delay counter for redis, may be nil
	connDelay    *delay_counter.DelayRecent // delay counter for connect to redis, may be nil
}

// redisCluster is one weighted redis cluster with its resolved servers and pool.
type redisCluster struct {
	bns         string       // bns for redis cluster
	weight      int          // weight for current cluster
	redisClient *RedisClient // associate redis client pointer

	serversLock sync.RWMutex // lock for servers
	Servers     []string     // tcp address for redis servers
	pool        *redis.Pool  // connection pool to redis server
	poolLock    sync.RWMutex // lock for pool
}

// RedisClusterConf is one parsed entry of the service configuration string.
type RedisClusterConf struct {
	bns    string // bns for redis cluster
	weight int    // weight for current cluster
}
the other condition: serviceConf is a batch of bns name with weight + // 2.1 parse and check confList from serviceConf string + confStrList := strings.Split(serviceConf, "|") + if len(confStrList) == 0 { + return []RedisClusterConf{}, fmt.Errorf("split redis serviceConf(%s) err", serviceConf) + } + clusterSize := len(confStrList) + confList := make([]RedisClusterConf, clusterSize) + for i, confStr := range confStrList { + confElements := strings.Split(confStr, ",") + if len(confElements) != 2 { + return []RedisClusterConf{}, fmt.Errorf("split redis serviceConf(%s) by ',' length err", confStr) + } + + confList[i].bns = confElements[0] + + weightElements := strings.Split(confElements[1], ":") + if len(weightElements) != 2 { + return []RedisClusterConf{}, fmt.Errorf("split redis serviceConf(%s) weightStr(%s) by ':' length err", + confStr, confElements[1]) + } + if weightElements[0] != "weight" { + return []RedisClusterConf{}, fmt.Errorf("split redis serviceConf(%s) weightStr(%s) by ':' find no 'weight'", + confStr, confElements[1]) + } + weight, err := strconv.Atoi(weightElements[1]) + if err != nil { + return []RedisClusterConf{}, fmt.Errorf("check redis serviceConf(%s) weight(%s) err(%s)", + confStr, weightElements[1], err.Error()) + } + if weight > MaxWeightValue || weight < MinWeightValue { + return []RedisClusterConf{}, fmt.Errorf("check redis serviceConf(%s) weight(%s) err, weight should be [%d, %d])", + confStr, weightElements[1], MinWeightValue, MaxWeightValue) + } + confList[i].weight = weight + } + + // 2.2 check bns name conflict and weight sum + weightSum := 0 + bnsConflictChecker := make(map[string]bool) + for _, conf := range confList { + if _, ok := bnsConflictChecker[conf.bns]; ok { + return []RedisClusterConf{}, + fmt.Errorf("check redis serviceConf(%s) err: bns(%s) conflict", serviceConf, conf.bns) + } + bnsConflictChecker[conf.bns] = true + weightSum = weightSum + conf.weight + } + if weightSum > MaxWeightSum { + return []RedisClusterConf{}, 
fmt.Errorf("check redis serviceConf(%s) err: weight sum overlimit(%d)", + serviceConf, MaxWeightSum) + } + + return confList, nil +} + +// NewRedisClient(): create a new redisClient with bns support +// Notice: +// - if resolve bns error, c.Servers will be empty. +// Params: +// - serviceConf: string, bns name or a batch of bns name with weight of redis server +// - maxIdle: int, max idle connections in connection pool +// - ct: int, connect redis server timeout, in ms +// - rt: int, read redis server timeout, in ms +// - wt: int, write redis server timeout, in ms +// Returns: +// - *redisClient: a new redis client +func NewRedisClient1(serviceConf string, maxIdle int, ct, rt, wt int) *RedisClient { + return NewRedisBnsClient(&Options{ + ServiceConf: serviceConf, + MaxIdle: maxIdle, + ConnTimeoutMs: ct, + ReadTimeoutMs: rt, + WriteTimeoutMs: wt, + }) +} + +// NewRedisClient2(): create a new redisClient with bns support +// Notice: +// - if resolve bns error, c.Servers will be empty. +// Params: +// - serviceConf: string, bns name or a batch of bns name with weight of redis server +// - maxIdle: int, max idle connections in connection pool +// - maxActive: int, max active connections in connection pool +// - wait: bool, if wait is true and pool at the maxActive limit, +// command waits for a connection return to the pool +// - ct: int, connect redis server timeout, in ms +// - rt: int, read redis server timeout, in ms +// - wt: int, write redis server timeout, in ms +// Returns: +// - *redisClient: a new redis client +func NewRedisClient2(serviceConf string, maxIdle, maxActive int, wait bool, ct, rt, wt int) *RedisClient { + return NewRedisBnsClient(&Options{ + ServiceConf: serviceConf, + MaxIdle: maxIdle, + ConnTimeoutMs: ct, + ReadTimeoutMs: rt, + WriteTimeoutMs: wt, + MaxActive: maxActive, + Wait: wait, + }) +} + +func (opts *Options) Format() error { + serviceConf, err := ParseRedisBnsConf(opts.ServiceConf) + if err != nil { + return fmt.Errorf("parse redis 
service conf %s err %s", opts.ServiceConf, err.Error()) + } + + opts.clusterList = serviceConf + return nil +} + +// NewRedisBnsClient(): create a new redisClient with bns support +// Notice: +// - if resolve bns error, c.Servers will be empty. +// Returns: +// - *redisClient: a new redis client +func NewRedisBnsClient(opts *Options) *RedisClient { + err := opts.Format() + if err != nil { + log.Logger.Warn(err.Error()) + return nil + } + + redisClusterConfList := opts.clusterList + + // create RedisClient + c := &RedisClient{ + Password: opts.Password, + + // timeout in ms + ConnectTimeout: time.Duration(opts.ConnTimeoutMs) * time.Millisecond, + ReadTimeout: time.Duration(opts.ReadTimeoutMs) * time.Millisecond, + WriteTimeout: time.Duration(opts.WriteTimeoutMs) * time.Millisecond, + + // max idle connection + MaxIdle: opts.MaxIdle, + + // max active connection + MaxActive: opts.MaxActive, + Wait: opts.Wait, + + // module state + // stateDelegate: nil, + moduleState2: nil, + delay: nil, + connDelay: nil, + } + + // create redis clusters + c.redisClusterSlotSize = 0 + c.redisClusters = make([]redisCluster, len(redisClusterConfList)) + for i, redisClusterConf := range redisClusterConfList { + c.redisClusters[i].bns = redisClusterConf.bns + c.redisClusters[i].weight = redisClusterConf.weight + c.redisClusters[i].redisClient = c + + c.redisClusterSlotSize = c.redisClusterSlotSize + uint64(redisClusterConf.weight) + + c.redisClusters[i].Servers, err = bns.NewClient().GetInstancesAddr(redisClusterConf.bns) + if err != nil { + log.Logger.Warn("get instance for %s err %s", redisClusterConf.bns, err.Error()) + } + + c.redisClusters[i].pool = &redis.Pool{ + MaxIdle: c.MaxIdle, + MaxActive: c.MaxActive, + Wait: c.Wait, + Dial: c.redisClusters[i].dial, + } + } + + // set redisClusterSlotMap + slotIndex := 0 + c.redisClusterSlotMap = make([]int, c.redisClusterSlotSize) + for id := range c.redisClusters { + for count := 0; count < c.redisClusters[id].weight; count++ { + 
c.redisClusterSlotMap[slotIndex] = id + slotIndex++ + } + } + + // goroutine to update bns + go c.checkServerInstance() + + return c +} + +// set state delegate to redisClient +// func (c *RedisClient) SetStateDelegate(delegate StateDelegate) { +// c.stateDelegate = delegate +// } + +// set state of module_state2 to redisClient +func (c *RedisClient) SetModuleState2(state *module_state2.State) { + c.moduleState2 = state +} + +// set delay counter to redisClient +func (c *RedisClient) SetDelay(delayCounter *delay_counter.DelayRecent) { + c.delay = delayCounter +} + +// set conn delay counter to redisClient +func (c *RedisClient) SetConnDelay(delayCounter *delay_counter.DelayRecent) { + c.connDelay = delayCounter +} + +// judge and set module_state2 by state string +func (c *RedisClient) incrModuleState2(state string) { + if c.moduleState2 != nil { + c.moduleState2.Inc(state, 1) + } +} + +// judge and set delay counter +func (c *RedisClient) setDelayState(delay *delay_counter.DelayRecent, start time.Time) { + if delay != nil { + delay.AddBySub(start, time.Now()) + } +} + +// dial choose a random server from redisCluster.Servers and connect +func (c *redisCluster) dial() (redis.Conn, error) { + c.redisClient.incrModuleState2(RedisConn) + + // choose a random server + c.serversLock.RLock() + if len(c.Servers) == 0 { + c.serversLock.RUnlock() + return nil, fmt.Errorf("no available connnection in pool") + } + server := c.Servers[rand.Intn(len(c.Servers))] + c.serversLock.RUnlock() + + // create connection to server + conn, err := redis.DialTimeout("tcp", + server, + c.redisClient.ConnectTimeout, + c.redisClient.ReadTimeout, + c.redisClient.WriteTimeout) + if err != nil { + c.redisClient.incrModuleState2(RedisConnFail) + return nil, err + } + + if password := c.redisClient.Password; password != "" { + if _, err := conn.Do("AUTH", password); err != nil { + c.redisClient.incrModuleState2(RedisAuthFail) + conn.Close() + return nil, err + } + } + + return conn, nil +} + +func 
(c *redisCluster) UpdateServers(servers []string) { + c.serversLock.Lock() + c.Servers = servers + c.serversLock.Unlock() +} + +func (c *redisCluster) UpdatePool(pool *redis.Pool) *redis.Pool { + c.poolLock.RLock() + oldPool := c.pool + c.pool = pool + c.poolLock.RUnlock() + + return oldPool +} + +// ActiveConnNum returns the num of active connextions +func (c *RedisClient) ActiveConnNum() int { + activeCountSum := 0 + for id := range c.redisClusters { + c.redisClusters[id].poolLock.RLock() + activeCountSum += c.redisClusters[id].pool.ActiveCount() + c.redisClusters[id].poolLock.RUnlock() + } + + return activeCountSum +} + +// Setex(): save key:value to redis server, and set expire time +// Params: +// - key: string +// - value: []byte +// - expire: int, expire time in second +// Returns: +// - nil, if success, otherwise return error +//save sessionState to session cache +func (c *RedisClient) Setex(key string, value []byte, expire int) (err error) { + c.incrModuleState2(RedisSetex) + + // get a connection + conn := c.getConnByKey(key) + defer conn.Close() + + procStart := time.Now() + // send setex cmd + conn.Send("SETEX", key, expire, value) + conn.Flush() + if _, err = conn.Receive(); err != nil { + c.incrModuleState2(RedisSetexFail) + return err + } + + c.setDelayState(c.delay, procStart) + return nil +} + +// get value from redis +func (c *RedisClient) Get(key string) (interface{}, error) { + c.incrModuleState2(RedisGet) + + // get connection from pool + conn := c.getConnByKey(key) + defer conn.Close() + + procStart := time.Now() + // get session state from redis + value, err := conn.Do("GET", key) + // redigo may return both value and err is nil + if value == nil && err == nil { + c.incrModuleState2(RedisGetMiss) + return nil, redis.ErrNil + } + // handle err is not nil + if err != nil { + if err != redis.ErrNil { + c.incrModuleState2(RedisGetFail) + } else { + c.incrModuleState2(RedisGetMiss) + } + return nil, err + } + + c.setDelayState(c.delay, procStart) 
+ c.incrModuleState2(RedisGetHit) + return value, nil +} + +// get value from redis +func (c *RedisClient) GetInt64(key string) (int64, error) { + c.incrModuleState2(RedisGet) + + // get connection from pool + conn := c.getConnByKey(key) + defer conn.Close() + + procStart := time.Now() + // get session state from redis + value, err := redis.Int64(conn.Do("GET", key)) + // handle err is not nil + if err != nil { + if err != redis.ErrNil { + c.incrModuleState2(RedisGetFail) + } else { + c.incrModuleState2(RedisGetMiss) + } + return 0, err + } + + c.setDelayState(c.delay, procStart) + c.incrModuleState2(RedisGetHit) + return value, nil +} + +// set expire to redis +func (c *RedisClient) Expire(key string, expire int) error { + c.incrModuleState2(RedisExpire) + + // get connection from pool + conn := c.getConnByKey(key) + defer conn.Close() + + procStart := time.Now() + // get session state from redis + _, err := conn.Do("EXPIRE", key, expire) + if err != nil { + c.incrModuleState2(RedisExpireFail) + return err + } + + c.setDelayState(c.delay, procStart) + return nil +} + +// incr key to redis +func (c *RedisClient) Incr(key string) (int64, error) { + c.incrModuleState2(RedisIncr) + + // get connection from pool + conn := c.getConnByKey(key) + defer conn.Close() + + procStart := time.Now() + // send incr & expire cmd + conn.Send("INCR", key) + conn.Flush() + // get result from incr cmd + count, err := redis.Int64(conn.Receive()) + if err != nil { + c.incrModuleState2(RedisIncrFail) + return count, err + } + + c.setDelayState(c.delay, procStart) + return count, nil +} + +// incr key to redis +func (c *RedisClient) IncrBy(key string, delta int64) (int64, error) { + c.incrModuleState2(RedisIncr) + + // get connection from pool + conn := c.getConnByKey(key) + defer conn.Close() + + procStart := time.Now() + // send incr & expire cmd + conn.Send("INCRBY", key, delta) + conn.Flush() + // get result from incr cmd + count, err := redis.Int64(conn.Receive()) + if err != nil { + 
c.incrModuleState2(RedisIncrFail) + return count, err + } + + c.setDelayState(c.delay, procStart) + return count, nil +} + +// incr and expire key to redis +func (c *RedisClient) IncrAndExpire(key string, expire int) (int64, error) { + c.incrModuleState2(RedisIncr) + + // get connection from pool + conn := c.getConnByKey(key) + defer conn.Close() + + procStart := time.Now() + // send incr & expire cmd + conn.Send("INCR", key) + conn.Send("EXPIRE", key, expire) + conn.Flush() + // get result from incr cmd + count, err := redis.Int64(conn.Receive()) + if err != nil { + c.incrModuleState2(RedisIncrFail) + return count, err + } + + // get result from expire cmd + if _, err = conn.Receive(); err != nil { + c.incrModuleState2(RedisExpireFail) + return count, err + } + + c.setDelayState(c.delay, procStart) + return count, nil +} + +/* +do redis pipeline incr command, filter the keys by redis cluster id +set countList and errList as return value, only modify the members which belong to current cluster id +param: + keyList []string total key list, this function only use the members which belog to current cluster id + countList *[]int64 count list return value, only modify the members which belong to current cluster id + errList *[]error error return value, only modify the member with the offset is current cluster id +*/ +func (c *RedisClient) pincrByRedisClusterId(keyList []string, + clusterId int, + countList *[]int64, + errList *[]error) { + var err error + var count int64 + + // get a sub list for the keys belong to current cluster id + subKeyList := make([]string, 0) + for i := 0; i < len(keyList); i++ { + if c.getClusterIdByKey(keyList[i]) == clusterId { + subKeyList = append(subKeyList, keyList[i]) + } + } + + // if there is no sub keylist for current cluster id, just return + if len(subKeyList) == 0 { + return + } + + // get connection from pool + conn := c.getConnByClusterId(clusterId) + defer conn.Close() + + // send by pipeline + subCountList := make([]int64, 
len(subKeyList)) + for i := range subKeyList { + c.incrModuleState2(RedisIncr) + + // send incr cmd + if err = conn.Send("INCR", subKeyList[i]); err != nil { + c.incrModuleState2(RedisSendFail) + goto ret + } + } + + // flush + if err = conn.Flush(); err != nil { + c.incrModuleState2(RedisFlushFail) + goto ret + } + + // receive values + for i := range subKeyList { + // get result from incr cmd + if count, err = redis.Int64(conn.Receive()); err != nil { + c.incrModuleState2(RedisIncrFail) + goto ret + } + + // append to countList + subCountList[i] = count + } + +ret: + if err == nil { + subIndex := 0 + for i := 0; i < len(keyList); i++ { + if c.getClusterIdByKey(keyList[i]) == clusterId { + (*countList)[i] = subCountList[subIndex] + subIndex++ + } + } + } + (*errList)[clusterId] = err +} + +// PIncr incr keys in pipeline mode, seprate keyList by clusterId and do pincr concurrently +func (c *RedisClient) PIncr(keyList []string) ([]int64, error) { + var err error + var totalErrStr string + errList := make([]error, len(c.redisClusters)) + procStart := time.Now() + + if len(keyList) == 0 { + return []int64{}, fmt.Errorf("len err: keyList(%d)", len(keyList)) + } + countList := make([]int64, len(keyList), len(keyList)) + + // run pincr seprate by cluseter id concurrently + for id := range c.redisClusters { + c.pincrByRedisClusterId(keyList, + id, + &countList, + &errList) + } + + // wait for each response + for redisClusterId := 0; redisClusterId < len(c.redisClusters); redisClusterId++ { + if errList[redisClusterId] != nil { + totalErrStr += fmt.Sprintf("redisClusterId(%d) pincr err(%s), ", + redisClusterId, errList[redisClusterId].Error()) + } + } + + // hanele error response + if totalErrStr != "" { + err = fmt.Errorf(totalErrStr) + } + + c.setDelayState(c.delay, procStart) + return countList, err +} + +// decr key to redis +func (c *RedisClient) Decr(key string) (int64, error) { + c.incrModuleState2(RedisDecr) + + // get connection from pool + conn := 
c.getConnByKey(key) + defer conn.Close() + + procStart := time.Now() + // send decr cmd + conn.Send("DECR", key) + conn.Flush() + // get result from decr cmd + count, err := redis.Int64(conn.Receive()) + if err != nil { + c.incrModuleState2(RedisDecrFail) + return count, err + } + + c.setDelayState(c.delay, procStart) + return count, nil +} + +// get a connection from connection pool by redis key +// todo: change this to private function +func (c *RedisClient) getConnByKey(key string) redis.Conn { + return c.getConnByClusterId(c.getClusterIdByKey(key)) +} + +// get redis cluster id by redis key +func (c *RedisClient) getClusterIdByKey(key string) int { + slot := getHash([]byte(key), c.redisClusterSlotSize) + return c.redisClusterSlotMap[slot] +} + +// get a connection from connection pool by cluster id +func (c *RedisClient) getConnByClusterId(clusterId int) redis.Conn { + procStart := time.Now() + + // get connection pool + c.redisClusters[clusterId].poolLock.RLock() + pool := c.redisClusters[clusterId].pool + c.redisClusters[clusterId].poolLock.RUnlock() + + // get connection from pool + conn := pool.Get() + + c.setDelayState(c.connDelay, procStart) + return conn +} + +// update bns +func (c *RedisClient) checkServerInstance() { + for { + time.Sleep(DfBnsUpdateInterval) + + for id := range c.redisClusters { + // check addresses of redis servers + servers, err := bns.NewClient().GetInstancesAddr(c.redisClusters[id].bns) + if err != nil { + c.incrModuleState2(RedisGetBnsInstanceErr) + continue + } + if len(servers) == 0 { + c.incrModuleState2(RedisNoBnsInstance) + continue + } + if reflect.DeepEqual(servers, c.redisClusters[id].Servers) { + continue + } + + // update addresses of redis servers + c.redisClusters[id].UpdateServers(servers) + + // counter bns instance changed + c.incrModuleState2(RedisBnsInstanceChanged) + + // update connection pool + pool := &redis.Pool{ + MaxIdle: c.MaxIdle, + MaxActive: c.MaxActive, + Wait: c.Wait, + Dial: 
c.redisClusters[id].dial, + } + oldPool := c.redisClusters[id].UpdatePool(pool) + oldPool.Close() + } + } +} + +func getHash(value []byte, base uint64) int { + var hash uint64 + + if value == nil { + hash = uint64(rand.Uint32()) + } else { + hash = murmur3.Sum64(value) + } + + return int(hash % base) +} diff --git a/conf/bfe.conf b/conf/bfe.conf index 8a6fd8620..935a9f5f8 100644 --- a/conf/bfe.conf +++ b/conf/bfe.conf @@ -64,6 +64,8 @@ Modules = mod_prison Modules = mod_wasm Modules = mod_unified_waf +Modules = mod_ai_token_auth +Modules = mod_body_process # interval for get diff of proxy-state MonitorInterval = 20 diff --git a/conf/mod_ai_token_auth/mod_ai_token_auth.conf b/conf/mod_ai_token_auth/mod_ai_token_auth.conf new file mode 100644 index 000000000..3eb7be626 --- /dev/null +++ b/conf/mod_ai_token_auth/mod_ai_token_auth.conf @@ -0,0 +1,19 @@ +[basic] +ProductRulePath = mod_ai_token_auth/token_rule.data + +[redis] +# bns addr +#bns = BLB.ALB-redis +bns = BFE.poc-redis-wx + +# timeout in ms +connectTimeout = 20 +readTimeout = 20 +writeTimeout = 20 + +# max idle connections +maxIdle = 20 + +[log] +OpenDebug = false + diff --git a/conf/mod_ai_token_auth/token_rule.data b/conf/mod_ai_token_auth/token_rule.data new file mode 100644 index 000000000..2a332a310 --- /dev/null +++ b/conf/mod_ai_token_auth/token_rule.data @@ -0,0 +1,19 @@ +{ + "Config": { + "example_product" :[ + ] + }, + "Tokens": { + "example_product": { + "TESTKEY": { + "key": "TESTKEY", + "status": 1, + "name": "test", + "expired_time": -1, + "unlimited_quota": true + } + } + }, + "Version": "0" +} + diff --git a/conf/mod_body_process/body_process_rule.data b/conf/mod_body_process/body_process_rule.data new file mode 100644 index 000000000..c86e8f51e --- /dev/null +++ b/conf/mod_body_process/body_process_rule.data @@ -0,0 +1,8 @@ +{ + "Config": { + "example_product": [ + ] + }, + "Version": "0" +} + diff --git a/conf/mod_body_process/mod_body_process.conf 
b/conf/mod_body_process/mod_body_process.conf new file mode 100644 index 000000000..b89fa4ab8 --- /dev/null +++ b/conf/mod_body_process/mod_body_process.conf @@ -0,0 +1,6 @@ +[basic] +ProductRulePath = mod_body_process/body_process_rule.data + +[log] +OpenDebug = false + diff --git a/conf/server_data_conf/name_conf.data b/conf/server_data_conf/name_conf.data index df3cc8bc9..2e2271e0e 100644 --- a/conf/server_data_conf/name_conf.data +++ b/conf/server_data_conf/name_conf.data @@ -1,6 +1,11 @@ { "Version": "init version", "Config": { + "BFE.poc-redis-wx": [{ + "Host": "172.18.1.244", + "Port": 6379, + "Weight": 10 + }], "example.redis.cluster": [ { "Host": "192.168.1.1", diff --git a/docs/en_us/condition/condition_primitive_index.md b/docs/en_us/condition/condition_primitive_index.md index da5af7f7c..bb21de410 100644 --- a/docs/en_us/condition/condition_primitive_index.md +++ b/docs/en_us/condition/condition_primitive_index.md @@ -76,6 +76,10 @@ * [req_vip_in(vip_list)](./request/ip.md#req_vip_invip_list) * [req_vip_range(start_ip, end_ip)](./request/ip.md#req_vip_rangestart_ip-end_ip) +### body + + * [req_body_json_in(json_path, value_list, case_insensitive)]() + ## Response Primitive ### code diff --git a/docs/en_us/condition/request/body.md b/docs/en_us/condition/request/body.md new file mode 100644 index 000000000..53f8970b2 --- /dev/null +++ b/docs/en_us/condition/request/body.md @@ -0,0 +1,18 @@ +# Condition Primitives Related to Request Body + +## req_body_json_in(json_path, value_list, case_insensitive) + +* Meaning: Searches for the field specified by `json_path` in the JSON-formatted request body and checks if its value exactly matches any in `value_list`. +* Parameters + +| Parameter | Description | +| ---------------- | ---------------------------------------------- | +| json_path | String
The path to the JSON field in the request body | +| value_list | String
List of values, separated by ‘|’ | +| case_insensitive | Boolean
Whether to ignore case sensitivity | + +* Example + +```go +req_body_json_in("model", "deepseek-r1|qwen-plus", true) +``` diff --git a/docs/en_us/configuration/server_data_conf/cluster_conf.data.md b/docs/en_us/configuration/server_data_conf/cluster_conf.data.md index ea771cba0..657feea9c 100644 --- a/docs/en_us/configuration/server_data_conf/cluster_conf.data.md +++ b/docs/en_us/configuration/server_data_conf/cluster_conf.data.md @@ -82,6 +82,13 @@ Note: The following configuration items are located in the namespace `Config[v]` | HTTPSConf.RSCAList | []String
Required when BackendConf.Protocol is https and server certificate verification is needed (i.e., RSInsecureSkipVerify is false). If not filled, the system default CA pool is used. List items are certificate file paths. Certificate files must be in x509 standard PEM format. Multiple CA certificates in the CA trust chain can be combined into one PEM file. | | HTTPSConf.RSInsecureSkipVerify | Boolean
Server certificate verification switch
true: Do not verify, false: Verify (default) | +#### AI Service Configuration + +| Configuration Item | Description | +| ------------------------------- | ------------------------------------------------------------------------------------------------------------- | +| AIConf.Key | String
API-Key for the backend large model service
If empty, the API-Key is not reset when accessing the backend service and the request's API-Key is retained | +| ModelMapping | Map\[string\]string
Mapping from original request model to backend service model. When accessing the backend service, the model field in the request will be looked up in this mapping; if matched, the model field in the request will be overwritten | + ## Configuration Example ```json diff --git a/docs/en_us/modules/mod_ai_token_auth/mod_ai_token_auth.md b/docs/en_us/modules/mod_ai_token_auth/mod_ai_token_auth.md new file mode 100644 index 000000000..9b350f5e8 --- /dev/null +++ b/docs/en_us/modules/mod_ai_token_auth/mod_ai_token_auth.md @@ -0,0 +1,107 @@ +# mod_ai_token_auth + +## Module Overview + +mod_ai_token_auth supports API-key (token) authentication for LLM services. An API-key represents a token with certain access permissions and quotas for specific LLM services. This module checks the API-key carried in the request according to rules to determine whether the request is allowed to access the LLM service. + +Request header carries the API-key: +``` +Authorization: Bearer +``` + +## Basic Configuration + +### Configuration Description + +Module configuration file: conf/mod_ai_token_auth/mod_ai_token_auth.conf + +| Option | Description | +| ------------------- | ------------------------------------------------- | +| Basic.ProductRulePath | String
File path for API-key declaration and rule configuration | +| redis.bns | String
BNS name of the Redis service. Redis is used to store API-key quota usage. | +| Log.OpenDebug | Boolean
Enable debug logs
Default: False | + +### Configuration Example + +```ini +[basic] +ProductRulePath = mod_ai_token_auth/token_rule.data + +[redis] +# bns addr +bns = BLB.ALB-redis + +# timeout in ms +connectTimeout = 20 +readTimeout = 20 +writeTimeout = 20 + +# max idle connections +maxIdle = 20 + +[log] +OpenDebug = false +``` + +## Rule Configuration + +### Configuration Description + +| Option | Description | +| --------------------- | ------------------------------------------------- | +| Version | String
Configuration file version | +| Tokens | Object
API-key declarations for all product lines | +| Tokens{k} | String
Product line name| +| Tokens{v} | Object
All API-keys under a product line | +| Tokens{v}{k} | String
An API-key | +| Tokens{v}{v} | Object
An API-key declaration, data structure below. | +| Config | Object
API-key authentication rule configuration for all product lines | +| Config{k} | String
Product line name| +| Config{v} | Array
API-key authentication rule list under a product line | +| Config{v}[] | Object
API-key authentication rule | +| Config{v}[].Cond | String
Matching condition, syntax details in [Condition](../../condition/condition_grammar.md) | +| Config{v}[].Action | Object
Action. Only one action is supported: { "cmd": "CHECK_TOKEN" } | + +API-key declaration data structure: +``` +struct { + Key string // API-key + Status int // API-key status: 1 - Enabled; 2 - Disabled; 3 - Expired; 4 - Exhausted + Name string // Name + UpdateTime int64 // Update time (Unix Time). Change means a new quota consumption cycle starts, recalculating UsedQuota. + ExpiredTime int64 // Expiry time (Unix Time). -1 means never expires + RemainQuota int64 // Total available quota (unit: token) + UnlimitedQuota bool // Unlimited quota or not + Models *string // Allowed model list, multiple model names separated by commas + Subnet *string // Allowed source IP subnet +} +``` + +### Configuration Example + +```json +{ + "Config": { + "example_product" :[ + { + "cond": "default_t()", + "action": { + "cmd": "CHECK_TOKEN" + } + } + ] + }, + "Tokens": { + "example_product": { + "TESTKEY": { + "key": "TESTKEY", + "status": 1, + "name": "test", + "expired_time": -1, + "unlimited_quota": true + } + } + }, + "Version": "20190101000000" +} +``` diff --git a/docs/en_us/modules/mod_body_process/mod_body_process.md b/docs/en_us/modules/mod_body_process/mod_body_process.md new file mode 100644 index 000000000..2be93c1d6 --- /dev/null +++ b/docs/en_us/modules/mod_body_process/mod_body_process.md @@ -0,0 +1,96 @@ +# mod_body_process +## Module Overview + +mod_body_process provides a streaming body processing framework. In many scenarios, the request or response body is streamed, such as SSE. For streaming data, if processing is required (e.g., content review), it cannot be cached and processed as a whole; instead, it must be processed chunk by chunk in real time as it is received. + +Within this streaming processing framework, body data goes through three steps: +* decoder - Parses received data into events in real time +* processors - A sequence of event processors. Each processor transforms input events into output events and can terminate the process by reporting errors. 
Events generated by the decoder are processed sequentially by the event processor chain +* encoder - Re-encodes events into body data + +Users can customize the processing flow for request or response bodies via configuration rules. Supported components include: +### decoder +* line - Parses data line by line, each line as an event +* json - Parses JSON objects from data, each JSON object as an event +* sse - Parses SSE events from data +* default - Automatically selects decoder based on contentType +### processor +* textfilter - Calls the ToolGood.TextFilter service for content review +### encoder +* default - Directly calls the event's ToBytes() function to generate body data + +## Basic Configuration + +### Configuration Description + +Module config file: conf/mod_body_process/mod_body_process.conf + +| Option | Description | +| ------------------- | ------------------------------------------------- | +| Basic.ProductRulePath | String
Path to the rule config file | +| Log.OpenDebug | Boolean
Enable debug logs
Default: False | + +### Configuration Example + +```ini +[basic] +ProductRulePath = ../data/mod_body_process/body_process_rule.data + +[log] +OpenDebug = false +``` + +## Rule Configuration + +### Configuration Description + +| Option | Description | +| --------------------- | ------------------------------------------------- | +| Version | String
Config file version | +| Config | Object
Body processing rule config for all product lines | +| Config{k} | String
Product line name| +| Config{v} | Array
Body processing rule list for the product line | +| Config{v}[] | Object
Body processing rule | +| Config{v}[].Cond | String
Matching condition, syntax details in [Condition](../../condition/condition_grammar.md) | +| Config{v}[].RequestProcess | Object
Request body processing flow config, see structure below | +| Config{v}[].ResponseProcess | Object
Response body processing flow config, see structure below | + +Data structure for body processing flow config: +``` +// Processing flow +struct { + Dec string // decoder, uses default if not specified + Enc string // encoder, uses default if not specified + Proc []ProcConf // processor list +} +// ProcConf +struct { + Name string // processor name, currently only supports "textfilter" + Params []string // processor parameter list. textfilter: Params[0] - ToolGood.TextFilter service URL +} +``` + +### Configuration Example + +```json +{ + "Config": { + "example_product": [ + { + "Cond": "!req_body_json_in(\"model\", \"\", false)", + "RequestProcess": { + "Proc": [ + {"name":"textfilter", "params":["http://172.19.1.136:9191/api/"]} + ] + }, + "ResponseProcess": { + "Proc": [ + {"name":"textfilter", "params":["http://172.19.1.136:9191/api/"]} + ] + } + } + ] + }, + "Version": "20190101000000" +} +``` diff --git a/docs/zh_cn/condition/condition_primitive_index.md b/docs/zh_cn/condition/condition_primitive_index.md index 7b5a4ba85..e427997a6 100644 --- a/docs/zh_cn/condition/condition_primitive_index.md +++ b/docs/zh_cn/condition/condition_primitive_index.md @@ -76,6 +76,10 @@ * [req_vip_in(vip_list)](./request/ip.md#req_vip_invip_list) * [req_vip_range(start_ip, end_ip)](./request/ip.md#req_vip_rangestart_ip-end_ip) +### body + + * [req_body_json_in(json_path, value_list, case_insensitive)]() + ## 响应相关 ### code diff --git a/docs/zh_cn/condition/request/body.md b/docs/zh_cn/condition/request/body.md new file mode 100644 index 000000000..6a78303da --- /dev/null +++ b/docs/zh_cn/condition/request/body.md @@ -0,0 +1,18 @@ +# 请求body相关条件原语 + +## req_body_json_in(json_path, value_list, case_insensitive) + +* 含义: 在json格式的请求body中,查找json_path指定的字段,判断其值是否精确匹配value_list之一 +* 参数 + +| 参数 | 描述 | +| -------- | ---------------------- | +| json_path | String
请求body中的json字段的路径 | +| value_list | String
value列表,多个之间使用‘|’连接 | +| case_insensitive | Boolean
是否忽略大小写 | + +* 示例 + +```go +req_body_json_in("model", "deepseek-r1|qwen-plus", true) +``` diff --git a/docs/zh_cn/configuration/server_data_conf/cluster_conf.data.md b/docs/zh_cn/configuration/server_data_conf/cluster_conf.data.md index 9bb7d851a..f6b759177 100644 --- a/docs/zh_cn/configuration/server_data_conf/cluster_conf.data.md +++ b/docs/zh_cn/configuration/server_data_conf/cluster_conf.data.md @@ -83,6 +83,13 @@ cluster_conf.data为集群转发配置文件。 | HTTPSConf.RSCAList | []String
BackendConf.Protocol为https,并且需要验证服务端的证书(即RSInsecureSkipVerify为false)时必填,如果不填则使用系统默认CA池。列表项为证书文件路径,证书文件必须是符合x509标准的pem格式证书,允许将CA信任链中的多个CA证书合入一个pem文件中。| | HTTPSConf.RSInsecureSkipVerify | Boolean
服务端证书验证开关
true:不验证,false:验证(默认)| +#### AI服务配置 + +| 配置项 | 描述 | +| --------------------------------- | ------------------------------------------------------------ | +| AIConf.Key | String
后端大模型服务的API-Key
空 - 访问后端服务时不重置API-Key,仍保持请求的API-Key | +| ModelMapping | Map\[string\]string
原请求model -> 后端服务的model 的映射关系。访问后端服务时将根据请求的 model 字段查找此映射关系,命中的话则重写请求的 model 字段| + ## 配置示例 ```json diff --git a/docs/zh_cn/modules/mod_ai_token_auth/mod_ai_token_auth.md b/docs/zh_cn/modules/mod_ai_token_auth/mod_ai_token_auth.md new file mode 100644 index 000000000..dc29d1925 --- /dev/null +++ b/docs/zh_cn/modules/mod_ai_token_auth/mod_ai_token_auth.md @@ -0,0 +1,107 @@ +# mod_ai_token_auth + +## 模块简介 + +mod_ai_token_auth 支持大模型 api-key(token) 鉴权。一个 api-key 代表一个对某些大模型服务拥有一定访问权限和配额的令牌。在此模块中根据规则对请求中携带的 api-key 进行检查,决定该请求是否允许访问大模型服务。 + +请求 header 携带 api-key: +``` +Authorization: Bearer +``` + +## 基础配置 + +### 配置描述 + +模块配置文件: conf/mod_ai_token_auth/mod_ai_token_auth.conf + +| 配置项 | 描述 | +| ------------------- | ------------------------------------------- | +| Basic.ProductRulePath | String
api-key声明和规则配置的文件路径 | +| redis.bns | String
redis服务的bns名。redis用于存储api-key的配额使用量。 | +| Log.OpenDebug | Boolean
是否开启 debug 日志
默认值False | + +### 配置示例 + +```ini +[basic] +ProductRulePath = mod_ai_token_auth/token_rule.data + +[redis] +# bns addr +bns = BLB.ALB-redis + +# timeout in ms +connectTimeout = 20 +readTimeout = 20 +writeTimeout = 20 + +# max idle connections +maxIdle = 20 + +[log] +OpenDebug = false +``` + +## 规则配置 + +### 配置描述 + +| 配置项 | 描述 | +| ---------------------| ------------------------------------------- | +| Version | String
配置文件版本 | +| Tokens | Object
所有产品线的 api-key 声明 | +| Tokens{k} | String
产品线名称| +| Tokens{v} | Object
产品线下的所有 api-key | +| Tokens{v}{k} | String
一个 api-key | +| Tokens{v}{v} | Object
一个 api-key 声明,数据结构见下。 | +| Config | Object
所有产品线的 api-key 鉴权规则配置 | +| Config{k} | String
产品线名称| +| Config{v} | Array
产品线下 api-key 鉴权规则列表 | +| Config{v}[] | Object
api-key 鉴权规则 | +| Config{v}[].Cond | String
匹配条件, 语法详见[Condition](../../condition/condition_grammar.md) | +| Config{v}[].Action | Object
动作。只支持一种动作:{ "cmd": "CHECK_TOKEN" } | + +api-key 声明的数据结构: +``` +struct { + Key string // api-key + Status int // api-key的状态:1 - Enabled; 2 - Disabled; 3 - Expired; 4 - Exhausted + Name string // 名字 + UpdateTime int64 // 更新时间 (Unix Time)。改变意味着开启一个新的配额消费周期,重新开始计算UsedQuota。 + ExpiredTime int64 // 过期时间 (Unix Time)。 -1 - 永不过期 + RemainQuota int64 // 总可用配额 (单位: token) + UnlimitedQuota bool // 是否无限配额 + Models *string // 允许的模型列表,多个模型名由逗号分开 + Subnet *string // 允许的源ip子网 +} +``` + +### 配置示例 + +```json +{ + "Config": { + "example_product" :[ + { + "cond": "default_t()", + "action": { + "cmd": "CHECK_TOKEN" + } + } + ] + }, + "Tokens": { + "example_product": { + "TESTKEY": { + "key": "TESTKEY", + "status": 1, + "name": "test", + "expired_time": -1, + "unlimited_quota": true + } + } + }, + "Version": "20190101000000" +} +``` diff --git a/docs/zh_cn/modules/mod_body_process/mod_body_process.md b/docs/zh_cn/modules/mod_body_process/mod_body_process.md new file mode 100644 index 000000000..dca78f375 --- /dev/null +++ b/docs/zh_cn/modules/mod_body_process/mod_body_process.md @@ -0,0 +1,97 @@ +# mod_body_process + +## 模块简介 + +mod_body_process 提供了一个 body 的流式处理框架。在许多场景中,请求或应答的body是流式的,例如SSE。对于流式的数据,如果需要做某种处理的话,例如内容审查,我们不能整体缓存下来再处理,而只能是一边接收一边处理,实时地一块一块地处理。 + +在这个流式处理框架中,body数据将依次经过三个步骤: +* decoder - 实时地将已接收到的数据解析为事件 +* processors - 事件处理器序列。每个处理器都是将输入事件转化为输出事件,也可以报错从而终止处理流程。由decoder产生的事件将依次经过事件处理器序列的处理 +* encoder - 将事件重新编码为body数据 + +用户可以通过配置规则定制请求或应答body的处理流程。目前支持的各种组件: +### decoder +* line - 将数据按行解析,每一行作为一个事件 +* json - 从数据中解析json对象,每个json对象作为一个事件 +* sse - 从数据中解析 sse 事件 +* 缺省 - 根据contentType自适应选择decoder +### processor +* textfilter - 调用 ToolGood.TextFilter 服务,对内容进行审查 +### encoder +* 缺省 - 直接调用事件的 ToBytes() 函数生成 body 数据 + +## 基础配置 + +### 配置描述 + +模块配置文件: conf/mod_body_process/mod_body_process.conf + +| 配置项 | 描述 | +| ------------------- | ------------------------------------------- | +| Basic.ProductRulePath | String
规则配置的文件路径 | +| Log.OpenDebug | Boolean
是否开启 debug 日志
默认值False | + +### 配置示例 + +```ini +[basic] +ProductRulePath = ../data/mod_body_process/body_process_rule.data + +[log] +OpenDebug = false +``` + +## 规则配置 + +### 配置描述 + +| 配置项 | 描述 | +| ---------------------| ------------------------------------------- | +| Version | String
配置文件版本 | +| Config | Object
所有产品线的 body 处理规则配置 | +| Config{k} | String
产品线名称| +| Config{v} | Array
产品线下 body 处理规则列表 | +| Config{v}[] | Object
body 处理规则 | +| Config{v}[].Cond | String
匹配条件, 语法详见[Condition](../../condition/condition_grammar.md) | +| Config{v}[].RequestProcess | Object
请求body的处理流程配置,数据结构见下 | +| Config{v}[].ResponseProcess | Object
应答body的处理流程配置,数据结构见下 | + +body的处理流程配置的数据结构: +``` +// 处理流程 +struct { + Dec string // decoder,不配置则使用缺省decoder + Enc string // encoder,不配置则使用缺省encoder + Proc []ProcConf // 处理器列表 +} +// ProcConf +struct { + Name string // 处理器名。目前只支持 “textfilter” + Params []string // 处理器的参数表。textfilter: Params[0] - ToolGood.TextFilter 服务的URL +} +``` + +### 配置示例 + +```json +{ + "Config": { + "example_product": [ + { + "Cond": "!req_body_json_in(\"model\", \"\", false)", + "RequestProcess": { + "Proc": [ + {"name":"textfilter", "params":["http://172.19.1.136:9191/api/"]} + ] + }, + "ResponseProcess": { + "Proc": [ + {"name":"textfilter", "params":["http://172.19.1.136:9191/api/"]} + ] + } + } + ] + }, + Version": "20190101000000" +} +``` diff --git a/go.mod b/go.mod index 8c9e57825..69d146ec6 100644 --- a/go.mod +++ b/go.mod @@ -42,6 +42,14 @@ require ( require ( github.com/bfenetworks/proxy-wasm-go-host v0.0.0-20241202144118-62704e5df808 github.com/go-jose/go-jose/v4 v4.0.5 + github.com/google/uuid v1.3.0 + github.com/tidwall/gjson v1.18.0 + github.com/tidwall/sjson v1.2.5 +) + +require ( + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect ) require ( diff --git a/go.sum b/go.sum index 9dda25134..da085ac17 100644 --- a/go.sum +++ b/go.sum @@ -57,6 +57,8 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY= github.com/gorilla/css v1.0.0/go.mod 
h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c= @@ -139,6 +141,15 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tetratelabs/wazero v1.2.1 h1:J4X2hrGzJvt+wqltuvcSjHQ7ujQxA9gb6PeMs4qlUWs= github.com/tetratelabs/wazero v1.2.1/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tjfoc/gmsm v1.3.2 h1:7JVkAn5bvUJ7HtU08iW6UiD+UTmJTIToHCfeFzkcCxM= github.com/tjfoc/gmsm v1.3.2/go.mod h1:HaUcFuY0auTiaHB9MHFGCPx5IaLhTUd2atbCFBQXn9w= github.com/uber/jaeger-client-go v2.25.0+incompatible h1:IxcNZ7WRY1Y3G4poYlx24szfsn/3LvK9QHCq9oQw8+U=