diff --git a/common/types/response.go b/common/types/response.go
index d616aa8013..714c0dcb05 100644
--- a/common/types/response.go
+++ b/common/types/response.go
@@ -4,6 +4,7 @@ import (
 	"net/http"
 
 	"github.com/gin-gonic/gin"
+	"github.com/mitchellh/mapstructure"
 )
 
 // Response the response schema
@@ -13,6 +14,19 @@ type Response struct {
 	Data    interface{} `json:"data"`
 }
 
+func (resp *Response) DecodeData(out interface{}) error {
+	// Decode generically unmarshaled JSON (map[string]any, []any) into a typed
+	// struct, honoring `json` tags.
+	dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
+		TagName: "json",
+		Result:  out,
+	})
+	if err != nil {
+		return err
+	}
+	return dec.Decode(resp.Data)
+}
+
 // RenderJSON renders response with json
 func RenderJSON(ctx *gin.Context, errCode int, err error, data interface{}) {
 	var errMsg string
diff --git a/coordinator/conf/config_proxy.json b/coordinator/conf/config_proxy.json
new file mode 100644
index 0000000000..886c10bf51
--- /dev/null
+++ b/coordinator/conf/config_proxy.json
@@ -0,0 +1,34 @@
+{
+    "proxy_manager": {
+        "proxy_cli": {
+            "proxy_name": "proxy_name",
+            "secret": "client private key"
+        },
+        "auth": {
+            "secret": "proxy secret key",
+            "challenge_expire_duration_sec": 3600,
+            "login_expire_duration_sec": 3600
+        },
+        "verifier": {
+            "min_prover_version": "v4.4.45",
+            "verifiers": [
+                {
+                    "assets_path": "assets",
+                    "fork_name": "euclidV2"
+                },
+                {
+                    "assets_path": "assets",
+                    "fork_name": "feynman"
+                }
+            ]
+        }
+    },
+    "coordinators": {
+        "sepolia": {
+            "base_url": "http://localhost:8555",
+            "retry_count": 10,
+            "retry_wait_time_sec": 10,
+            "connection_timeout_sec": 30
+        }
+    }
+}
diff --git a/coordinator/internal/config/proxy_config.go b/coordinator/internal/config/proxy_config.go
new file mode 100644
index 0000000000..1ea8bd295c
--- /dev/null
+++ b/coordinator/internal/config/proxy_config.go
@@ -0,0 +1,71 @@
+package config
+
+import (
+	"encoding/json"
+	"os"
+	"path/filepath"
+
+	"scroll-tech/common/utils"
+)
+
+// ProxyManager holds the configuration items of the proxy.
+type ProxyManager struct {
+	// Zk verifier config, used to constrain the connected provers.
+	Verifier *VerifierConfig `json:"verifier"`
+	Client   *ProxyClient    `json:"proxy_cli"`
+	Auth     *Auth           `json:"auth"`
+}
+
+// Normalize fills optional client fields from the related settings.
+func (m *ProxyManager) Normalize() {
+	if m.Client.Secret == "" {
+		m.Client.Secret = m.Auth.Secret
+	}
+
+	if m.Client.ProxyVersion == "" {
+		m.Client.ProxyVersion = m.Verifier.MinProverVersion
+	}
+}
+
+// ProxyClient is the client-side configuration for connecting to upstreams.
+type ProxyClient struct {
+	ProxyName    string `json:"proxy_name"`
+	ProxyVersion string `json:"proxy_version,omitempty"`
+	Secret       string `json:"secret,omitempty"`
+}
+
+// UpStream is the configuration for an upstream coordinator.
+type UpStream struct {
+	BaseUrl              string `json:"base_url"`
+	RetryCount           uint   `json:"retry_count"`
+	RetryWaitTime        uint   `json:"retry_wait_time_sec"`
+	ConnectionTimeoutSec uint   `json:"connection_timeout_sec"`
+}
+
+// ProxyConfig loads the configuration items of the proxy.
+type ProxyConfig struct {
+	ProxyManager *ProxyManager        `json:"proxy_manager"`
+	ProxyName    string               `json:"proxy_name"`
+	Coordinators map[string]*UpStream `json:"coordinators"`
+}
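+
+// Illustrative only (not part of this patch): with the sample
+// conf/config_proxy.json above, loading the config yields one upstream
+// keyed "sepolia":
+//
+//	cfg, err := NewProxyConfig("conf/config_proxy.json")
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	up := cfg.Coordinators["sepolia"]
+//	_ = up.BaseUrl // "http://localhost:8555"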
+
+// NewProxyConfig returns a new instance of ProxyConfig.
+func NewProxyConfig(file string) (*ProxyConfig, error) {
+	buf, err := os.ReadFile(filepath.Clean(file))
+	if err != nil {
+		return nil, err
+	}
+
+	cfg := &ProxyConfig{}
+	err = json.Unmarshal(buf, cfg)
+	if err != nil {
+		return nil, err
+	}
+
+	// Override config with environment variables
+	err = utils.OverrideConfigWithEnv(cfg, "SCROLL_COORDINATOR_PROXY")
+	if err != nil {
+		return nil, err
+	}
+
+	return cfg, nil
+}
diff --git a/coordinator/internal/controller/api/auth.go b/coordinator/internal/controller/api/auth.go
index b38f0d827a..c2abb86f09 100644
--- a/coordinator/internal/controller/api/auth.go
+++ b/coordinator/internal/controller/api/auth.go
@@ -19,28 +19,56 @@ type AuthController struct {
 	loginLogic *auth.LoginLogic
 }
 
+// NewAuthControllerWithLogic returns an AuthController that uses the given login logic
+func NewAuthControllerWithLogic(loginLogic *auth.LoginLogic) *AuthController {
+	return &AuthController{
+		loginLogic: loginLogic,
+	}
+}
+
 // NewAuthController returns an AuthController instance
 func NewAuthController(db *gorm.DB, cfg *config.Config, vf *verifier.Verifier) *AuthController {
 	return &AuthController{
-		loginLogic: auth.NewLoginLogic(db, cfg, vf),
+		loginLogic: auth.NewLoginLogic(db, cfg.ProverManager.Verifier, vf),
 	}
 }
 
-// Login the api controller for login
+// Login the api controller for login, used as the Authenticator in JWT.
+// It can work in two modes: the full process for a normal login, or, if the
+// login request is posted from a proxy, a simpler process to log in a client.
 func (a *AuthController) Login(c *gin.Context) (interface{}, error) {
+	// check if the login request is posted by a proxy
+	var viaProxy bool
+	if proverType, proverTypeExist := c.Get(types.ProverProviderTypeKey); proverTypeExist {
+		providerType := types.ProverProviderType(proverType.(float64))
+		viaProxy = providerType == types.ProverProviderTypeProxy
+	}
+
 	var login types.LoginParameter
 	if err := c.ShouldBind(&login); err != nil {
 		return "", fmt.Errorf("missing the public_key, err:%w", err)
 	}
 
-	// check login parameter's token is equal to bearer token, the Authorization must be existed
-	// if not exist, the jwt token will intercept it
-	brearToken := c.GetHeader("Authorization")
-	if brearToken != "Bearer "+login.Message.Challenge {
-		return "", errors.New("check challenge failure for the not equal challenge string")
+	// if not via proxy, proceed with the normal login flow
+	if !viaProxy {
+		// check the login parameter's token is equal to the bearer token; the
+		// Authorization header must exist, otherwise the jwt middleware intercepts the request
+		bearerToken := c.GetHeader("Authorization")
+		if bearerToken != "Bearer "+login.Message.Challenge {
+			return "", errors.New("check challenge failure for the not equal challenge string")
+		}
+
+		if err := auth.VerifyMsg(&login); err != nil {
+			return "", err
+		}
+
+		// check whether the challenge has been used; if so, return failure
+		if err := a.loginLogic.InsertChallengeString(c, login.Message.Challenge); err != nil {
+			return "", fmt.Errorf("login insert challenge string failure:%w", err)
+		}
 	}
 
-	if err := a.loginLogic.Check(&login); err != nil {
+	if err := a.loginLogic.CompatibilityCheck(&login); err != nil {
 		return "", fmt.Errorf("check the login parameter failure: %w", err)
 	}
 
@@ -49,11 +77,6 @@ func (a *AuthController) Login(c *gin.Context) (interface{}, error) {
 		return "", fmt.Errorf("prover hard fork name failure:%w", err)
 	}
 
-	// check the challenge is used, if used, return failure
-	if err := a.loginLogic.InsertChallengeString(c, login.Message.Challenge); err != nil {
-		return "", fmt.Errorf("login insert challenge string failure:%w", err)
-	}
-
 	returnData := types.LoginParameterWithHardForkName{
 		HardForkName:   hardForkNames,
 		LoginParameter: login,
@@ -85,10 +108,6 @@ func (a *AuthController) IdentityHandler(c *gin.Context) interface{} {
 		c.Set(types.ProverName, proverName)
 	}
 
-	if publicKey, ok := claims[types.PublicKey]; ok {
-		c.Set(types.PublicKey, publicKey)
-	}
-
 	if proverVersion, ok := claims[types.ProverVersion]; ok {
 		c.Set(types.ProverVersion, proverVersion)
 	}
@@ -101,5 +120,9 @@ func (a *AuthController) IdentityHandler(c *gin.Context) interface{} {
 		c.Set(types.ProverProviderTypeKey, providerType)
 	}
 
+	if publicKey, ok := claims[types.PublicKey]; ok {
+		return publicKey
+	}
+
 	return nil
 }
diff --git a/coordinator/internal/controller/proxy/auth.go b/coordinator/internal/controller/proxy/auth.go
new file mode 100644
index 0000000000..a312384d77
--- /dev/null
+++ b/coordinator/internal/controller/proxy/auth.go
@@ -0,0 +1,137 @@
+package proxy
+
+import (
+	"fmt"
+	"time"
+
+	jwt "github.com/appleboy/gin-jwt/v2"
+	"github.com/gin-gonic/gin"
+	"github.com/scroll-tech/go-ethereum/log"
+
+	"scroll-tech/coordinator/internal/config"
+	"scroll-tech/coordinator/internal/controller/api"
+	"scroll-tech/coordinator/internal/logic/auth"
+	"scroll-tech/coordinator/internal/logic/verifier"
+	"scroll-tech/coordinator/internal/types"
+)
+
+// AuthController handles the proxy login API
+type AuthController struct {
+	apiLogin  *api.AuthController
+	clients   Clients
+	proverMgr *ProverManager
+}
+
+const upstreamConnTimeout = time.Second * 2
+const LoginParamCache = "login_param"
+const ProverTypesKey = "prover_types"
+const SignatureKey = "prover_signature"
+
+// NewAuthController returns an AuthController instance
+func NewAuthController(cfg *config.ProxyConfig, clients Clients, proverMgr *ProverManager) *AuthController {
+	// use a dummy Verifier to create the login logic (we do not use any information in the verifier)
+	dummyVf := verifier.Verifier{
+		OpenVMVkMap: make(map[string]struct{}),
+	}
+	loginLogic := auth.NewLoginLogicWithSimpleDeduplicator(cfg.ProxyManager.Verifier, &dummyVf)
+
+	authController := &AuthController{
+		apiLogin:  api.NewAuthControllerWithLogic(loginLogic),
+		clients:   clients,
+		proverMgr: proverMgr,
+	}
+
+	return authController
+}
+
+// Login extends the Login handler in the api controller: after a successful
+// prover login it fans the login out to every upstream coordinator
func (a *AuthController) Login(c *gin.Context) (interface{}, error) {
+	loginRes, err := a.apiLogin.Login(c)
+	if err != nil {
+		return nil, err
+	}
+	loginParam := loginRes.(types.LoginParameterWithHardForkName)
+
+	if loginParam.LoginParameter.Message.ProverProviderType == types.ProverProviderTypeProxy {
+		return nil, fmt.Errorf("proxy does not support recursive login")
+	}
+
+	session := a.proverMgr.GetOrCreate(loginParam.PublicKey)
+
+	// copy the gin context so the goroutines can safely use it after this handler returns
+	cCopy := c.Copy()
+	for n, cli := range a.clients {
+		go func(n string, cli Client) {
+			if err := session.ProxyLogin(cCopy, cli, n, &loginParam.LoginParameter); err != nil {
+				log.Error("proxy login failed during token cache update",
+					"userKey", loginParam.PublicKey,
+					"upstream", n,
+					"error", err)
+			}
+		}(n, cli)
+	}
+
+	return loginParam.LoginParameter, nil
+}
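+
+// Illustrative only (hypothetical values): PayloadFunc below flattens the
+// prover's LoginParameter into JWT claims, and IdentityHandler rebuilds the
+// parameter from those claims on later requests:
+//
+//	claims := a.PayloadFunc(types.LoginParameter{PublicKey: "02ab..."})
+//	_ = claims[types.PublicKey] // "02ab..."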
+
+// PayloadFunc packs the login parameter fields (public key, prover name and
+// version, provider type, signature, prover types) into jwt.MapClaims
+func (a *AuthController) PayloadFunc(data interface{}) jwt.MapClaims {
+	v, ok := data.(types.LoginParameter)
+	if !ok {
+		return jwt.MapClaims{}
+	}
+
+	return jwt.MapClaims{
+		types.PublicKey:             v.PublicKey,
+		types.ProverName:            v.Message.ProverName,
+		types.ProverVersion:         v.Message.ProverVersion,
+		types.ProverProviderTypeKey: v.Message.ProverProviderType,
+		SignatureKey:                v.Signature,
+		ProverTypesKey:              v.Message.ProverTypes,
+	}
+}
+
+// IdentityHandler rebuilds the LoginParameter from the JWT claims, caches it
+// in the context, and returns the public key as the identity
+func (a *AuthController) IdentityHandler(c *gin.Context) interface{} {
+	claims := jwt.ExtractClaims(c)
+	loginParam := &types.LoginParameter{}
+
+	if proverName, ok := claims[types.ProverName]; ok {
+		loginParam.Message.ProverName, _ = proverName.(string)
+	}
+
+	if proverVersion, ok := claims[types.ProverVersion]; ok {
+		loginParam.Message.ProverVersion, _ = proverVersion.(string)
+	}
+
+	if providerType, ok := claims[types.ProverProviderTypeKey]; ok {
+		num, _ := providerType.(float64)
+		loginParam.Message.ProverProviderType = types.ProverProviderType(num)
+	}
+
+	if signature, ok := claims[SignatureKey]; ok {
+		loginParam.Signature, _ = signature.(string)
+	}
+
+	if proverTypes, ok := claims[ProverTypesKey]; ok {
+		arr, _ := proverTypes.([]any)
+		for _, elm := range arr {
+			num, _ := elm.(float64)
+			loginParam.Message.ProverTypes = append(loginParam.Message.ProverTypes, types.ProverType(num))
+		}
+	}
+
+	if publicKey, ok := claims[types.PublicKey]; ok {
+		loginParam.PublicKey, _ = publicKey.(string)
+	}
+
+	if loginParam.PublicKey != "" {
+		c.Set(LoginParamCache, loginParam)
+		return loginParam.PublicKey
+	}
+
+	return nil
+}
diff --git a/coordinator/internal/controller/proxy/client.go b/coordinator/internal/controller/proxy/client.go
new file mode 100644
index 0000000000..e3a6ebb456
--- /dev/null
+++ b/coordinator/internal/controller/proxy/client.go
@@ -0,0 +1,207 @@
+package proxy
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"time"
+
+	ctypes "scroll-tech/common/types"
+
+	"scroll-tech/coordinator/internal/config"
+	"scroll-tech/coordinator/internal/types"
+)
+
+// upClient wraps an http client with a preset base URL for coordinator API calls
+type upClient struct {
+	httpClient *http.Client
+	baseURL    string
+	loginToken string
+}
+
+// newUpClient creates a new upClient from the given upstream config
+func newUpClient(cfg *config.UpStream) *upClient {
+	return &upClient{
+		httpClient: &http.Client{
+			Timeout: time.Duration(cfg.ConnectionTimeoutSec) * time.Second,
+		},
+		baseURL: cfg.BaseUrl,
+	}
+}
+
+// Token returns the cached login token for the upstream
+func (c *upClient) Token() string {
+	return c.loginToken
+}
+
+// loginSchema is a parsable schema definition for challenge/login responses
+type loginSchema struct {
+	Time  string `json:"time"`
+	Token string `json:"token"`
+}
+
+// Login performs the complete login process: get a challenge, then log in
+func (c *upClient) Login(ctx context.Context, genLogin func(string) (*types.LoginParameter, error)) (*types.LoginSchema, error) {
+	// Step 1: Get challenge
+	url := fmt.Sprintf("%s/coordinator/v1/challenge", c.baseURL)
+
+	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create challenge request: %w", err)
+	}
+
+	challengeResp, err := c.httpClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get challenge: %w", err)
+	}
+
+	parsedResp, err := handleHttpResp(challengeResp)
+	if err != nil {
+		return nil, err
+	} else if parsedResp.ErrCode != 0 {
+		return nil, fmt.Errorf("challenge failed: %d (%s)", parsedResp.ErrCode, parsedResp.ErrMsg)
+	}
+
+	// Step 2: Parse challenge response
+	var challengeSchema loginSchema
+	if err := parsedResp.DecodeData(&challengeSchema); err != nil {
+		return nil, fmt.Errorf("failed to parse challenge response: %w", err)
+	}
+
+	// Step 3: Use the token from the challenge as the bearer token for login
+	url = fmt.Sprintf("%s/coordinator/v1/login", c.baseURL)
+
+	param, err := genLogin(challengeSchema.Token)
+	if err != nil {
+		return nil, fmt.Errorf("failed to set up login parameter: %w", err)
+	}
+
+	jsonData, err := json.Marshal(param)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal login parameter: %w", err)
+	}
+
+	req, err = http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create login request: %w", err)
+	}
+
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", "Bearer "+challengeSchema.Token)
+
+	loginResp, err := c.httpClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to perform login request: %w", err)
+	}
+
+	parsedResp, err = handleHttpResp(loginResp)
+	if err != nil {
+		return nil, err
+	} else if parsedResp.ErrCode != 0 {
+		return nil, fmt.Errorf("login failed: %d (%s)", parsedResp.ErrCode, parsedResp.ErrMsg)
+	}
+
+	var loginResult loginSchema
+	err = parsedResp.DecodeData(&loginResult)
+	if err != nil {
+		return nil, fmt.Errorf("failed to parse login data: %w", err)
+	}
+	c.loginToken = loginResult.Token
+
+	// TODO: we need to parse time if we start making use of it
+
+	return &types.LoginSchema{
+		Token: loginResult.Token,
+	}, nil
+}
+
+// handleHttpResp decodes a coordinator response body into ctypes.Response.
+// Note: the body is consumed here; the caller must not read it again.
+func handleHttpResp(resp *http.Response) (*ctypes.Response, error) {
+	defer resp.Body.Close()
+	if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusUnauthorized {
+		var respWithData ctypes.Response
+		if err := json.NewDecoder(resp.Body).Decode(&respWithData); err != nil {
+			return nil, fmt.Errorf("parsing expected response failed: %w", err)
+		}
+		return &respWithData, nil
+	}
+	return nil, fmt.Errorf("request failed with status: %d", resp.StatusCode)
+}
+
+// ProxyLogin makes a POST request to /v1/proxy_login with LoginParameter
+func (c *upClient) ProxyLogin(ctx context.Context, param *types.LoginParameter) (*ctypes.Response, error) {
+	url := fmt.Sprintf("%s/coordinator/v1/proxy_login", c.baseURL)
+
+	jsonData, err := json.Marshal(param)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal proxy login parameter: %w", err)
+	}
+
+	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create proxy login request: %w", err)
+	}
+
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", "Bearer "+c.loginToken)
+
+	proxyLoginResp, err := c.httpClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to perform proxy login request: %w", err)
+	}
+	return handleHttpResp(proxyLoginResp)
+}
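+
+// Illustrative only (hypothetical snippet): a caller drives the
+// challenge/login handshake by supplying a callback that signs the challenge
+// token into a LoginParameter (buildSignedParam is a hypothetical helper):
+//
+//	cli := newUpClient(cfg)
+//	schema, err := cli.Login(ctx, func(challenge string) (*types.LoginParameter, error) {
+//		return buildSignedParam(challenge)
+//	})
+//	if err == nil {
+//		_ = schema.Token // bearer token for later get_task / submit_proof calls
+//	}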
+
+// GetTask makes a POST request to /v1/get_task with GetTaskParameter
+func (c *upClient) GetTask(ctx context.Context, param *types.GetTaskParameter, token string) (*ctypes.Response, error) {
+	url := fmt.Sprintf("%s/coordinator/v1/get_task", c.baseURL)
+
+	jsonData, err := json.Marshal(param)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal get task parameter: %w", err)
+	}
+
+	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create get task request: %w", err)
+	}
+
+	req.Header.Set("Content-Type", "application/json")
+	if token != "" {
+		req.Header.Set("Authorization", "Bearer "+token)
+	}
+
+	resp, err := c.httpClient.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	return handleHttpResp(resp)
+}
+
+// SubmitProof makes a POST request to /v1/submit_proof with SubmitProofParameter
+func (c *upClient) SubmitProof(ctx context.Context, param *types.SubmitProofParameter, token string) (*ctypes.Response, error) {
+	url := fmt.Sprintf("%s/coordinator/v1/submit_proof", c.baseURL)
+
+	jsonData, err := json.Marshal(param)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal submit proof parameter: %w", err)
+	}
+
+	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create submit proof request: %w", err)
+	}
+
+	req.Header.Set("Content-Type", "application/json")
+	if token != "" {
+		req.Header.Set("Authorization", "Bearer "+token)
+	}
+
+	resp, err := c.httpClient.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	return handleHttpResp(resp)
+}
diff --git a/coordinator/internal/controller/proxy/client_manager.go b/coordinator/internal/controller/proxy/client_manager.go
new file mode 100644
index 0000000000..ad170384f9
--- /dev/null
+++ b/coordinator/internal/controller/proxy/client_manager.go
@@ -0,0 +1,181 @@
+package proxy
+
+import (
+	"context"
+	"crypto/ecdsa"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/scroll-tech/go-ethereum/common"
+	"github.com/scroll-tech/go-ethereum/crypto"
+	"github.com/scroll-tech/go-ethereum/log"
+
+	"scroll-tech/coordinator/internal/config"
+	"scroll-tech/coordinator/internal/types"
+)
+
+// Client provides a thread-safe, lazily logged-in client for one upstream coordinator
+type Client interface {
+	Client(context.Context) *upClient
+	Reset(cli *upClient)
+}
+
+// ClientManager implements Client and keeps the upstream client logged in
+type ClientManager struct {
+	name    string
+	cliCfg  *config.ProxyClient
+	cfg     *config.UpStream
+	privKey *ecdsa.PrivateKey
+
+	cachedCli struct {
+		sync.RWMutex
+		cli           *upClient
+		completionCtx context.Context
+	}
+}
+
+// buildPrivateKey deterministically derives a valid ECDSA private key from arbitrary secret bytes
+func buildPrivateKey(inputBytes []byte) (*ecdsa.PrivateKey, error) {
+	// Try appending bytes from 0x0 to 0x20 until we get a valid private key
+	for appendByte := byte(0x0); appendByte <= 0x20; appendByte++ {
+		// Append the byte to the input
+		extendedBytes := append(inputBytes, appendByte)
+
+		// Calculate the 256-bit hash
+		hash := crypto.Keccak256(extendedBytes)
+
+		// Try to create a private key from the hash
+		if k, err := crypto.ToECDSA(hash); err == nil {
+			return k, nil
+		}
+	}
+
+	return nil, fmt.Errorf("failed to generate valid private key from input bytes")
+}
+
+// NewClientManager creates a ClientManager for the named upstream
+func NewClientManager(name string, cliCfg *config.ProxyClient, cfg *config.UpStream) (*ClientManager, error) {
+	privKey, err := buildPrivateKey([]byte(cliCfg.Secret))
+	if err != nil {
+		return nil, err
+	}
+
+	return &ClientManager{
+		name:    name,
+		privKey: privKey,
+		cfg:     cfg,
+		cliCfg:  cliCfg,
+	}, nil
+}
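+
+// Illustrative only: the derivation above is deterministic, so the same
+// configured secret always yields the same proxy identity (public key)
+// across restarts:
+//
+//	k1, _ := buildPrivateKey([]byte("client private key"))
+//	k2, _ := buildPrivateKey([]byte("client private key"))
+//	// crypto.CompressPubkey(&k1.PublicKey) equals crypto.CompressPubkey(&k2.PublicKey)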
successful", "name", cliMgr.name, "time", loginResult.Time) + return loginResult.Time + } + log.Info("login to upstream coordinator failed, retrying", "name", cliMgr.name, "error", err, "waitDuration", waitDuration) + + timer := time.NewTimer(waitDuration) + select { + case <-ctx.Done(): + timer.Stop() + return time.Now() + case <-timer.C: + // Continue to next retry + } + } +} + +func (cliMgr *ClientManager) Reset(cli *upClient) { + cliMgr.cachedCli.Lock() + if cliMgr.cachedCli.cli == cli { + cliMgr.cachedCli.cli = nil + } + cliMgr.cachedCli.Unlock() + log.Info("cached client cleared", "name", cliMgr.name) +} + +func (cliMgr *ClientManager) Client(ctx context.Context) *upClient { + cliMgr.cachedCli.RLock() + if cliMgr.cachedCli.cli != nil { + defer cliMgr.cachedCli.RUnlock() + return cliMgr.cachedCli.cli + } + cliMgr.cachedCli.RUnlock() + + cliMgr.cachedCli.Lock() + if cliMgr.cachedCli.cli != nil { + defer cliMgr.cachedCli.Unlock() + return cliMgr.cachedCli.cli + } + + var completionCtx context.Context + // Check if completion context is set + if cliMgr.cachedCli.completionCtx != nil { + completionCtx = cliMgr.cachedCli.completionCtx + } else { + // Set new completion context and launch login goroutine + ctx, completionDone := context.WithCancel(context.TODO()) + loginCli := newUpClient(cliMgr.cfg) + completionCtx = context.WithValue(ctx, "cli", loginCli) + cliMgr.cachedCli.completionCtx = completionCtx + + // Launch keep-login goroutine + go func() { + defer completionDone() + cliMgr.doLogin(context.Background(), loginCli) + + cliMgr.cachedCli.Lock() + cliMgr.cachedCli.cli = loginCli + cliMgr.cachedCli.completionCtx = nil + + cliMgr.cachedCli.Unlock() + + }() + } + cliMgr.cachedCli.Unlock() + + // Wait for completion or request cancellation + select { + case <-ctx.Done(): + return nil + case <-completionCtx.Done(): + cli := completionCtx.Value("cli").(*upClient) + return cli + } +} + +func (cliMgr *ClientManager) genLoginParam(challenge string) (*types.LoginParameter, error) { + + // Generate public key string + publicKeyHex := common.Bytes2Hex(crypto.CompressPubkey(&cliMgr.privKey.PublicKey)) + + // Create login parameter with proxy settings + loginParam := &types.LoginParameter{ + Message: types.Message{ + Challenge: challenge, + ProverName: cliMgr.cliCfg.ProxyName, + ProverVersion: cliMgr.cliCfg.ProxyVersion, + ProverProviderType: types.ProverProviderTypeProxy, + ProverTypes: []types.ProverType{}, // Default empty + VKs: []string{}, // Default empty + }, + PublicKey: publicKeyHex, + } + + // Sign the message with the private key + if err := loginParam.SignWithKey(cliMgr.privKey); err != nil { + return nil, fmt.Errorf("failed to sign login parameter: %w", err) + } + + return loginParam, nil +} diff --git a/coordinator/internal/controller/proxy/controller.go b/coordinator/internal/controller/proxy/controller.go new file mode 100644 index 0000000000..0e1b217a3f --- /dev/null +++ b/coordinator/internal/controller/proxy/controller.go @@ -0,0 +1,43 @@ +package proxy + +import ( + "github.com/prometheus/client_golang/prometheus" + + "scroll-tech/coordinator/internal/config" +) + +var ( + // GetTask the prover task controller + GetTask *GetTaskController + // SubmitProof the submit proof controller + SubmitProof *SubmitProofController + // Auth the auth controller + Auth *AuthController +) + +// Clients manager a series of thread-safe clients for requesting upstream +// coordinators +type Clients map[string]Client + +// InitController inits Controller with database +func InitController(cfg 
+
+// InitController initializes the proxy controllers from the given config
+func InitController(cfg *config.ProxyConfig, reg prometheus.Registerer) {
+	// normalize cfg
+	cfg.ProxyManager.Normalize()
+
+	clients := make(map[string]Client)
+
+	for nm, upCfg := range cfg.Coordinators {
+		cli, err := NewClientManager(nm, cfg.ProxyManager.Client, upCfg)
+		if err != nil {
+			panic(fmt.Sprintf("create new client fail: %v", err))
+		}
+		clients[nm] = cli
+	}
+
+	proverManager := NewProverManager()
+	priorityManager := NewPriorityUpstreamManager()
+
+	Auth = NewAuthController(cfg, clients, proverManager)
+	GetTask = NewGetTaskController(cfg, clients, proverManager, priorityManager, reg)
+	SubmitProof = NewSubmitProofController(cfg, clients, proverManager, priorityManager, reg)
+}
diff --git a/coordinator/internal/controller/proxy/get_task.go b/coordinator/internal/controller/proxy/get_task.go
new file mode 100644
index 0000000000..8261ff0f16
--- /dev/null
+++ b/coordinator/internal/controller/proxy/get_task.go
@@ -0,0 +1,171 @@
+package proxy
+
+import (
+	"fmt"
+	"math/rand"
+	"sync"
+
+	"github.com/gin-gonic/gin"
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/scroll-tech/go-ethereum/log"
+
+	"scroll-tech/common/types"
+
+	"scroll-tech/coordinator/internal/config"
+	coordinatorType "scroll-tech/coordinator/internal/types"
+)
+
+// getSessionData extracts the prover public key bound by the login middleware;
+// on failure it renders an error response and returns an empty string
+func getSessionData(ctx *gin.Context) string {
+	publicKeyData, publicKeyExist := ctx.Get(coordinatorType.PublicKey)
+	publicKey, castOk := publicKeyData.(string)
+	if !publicKeyExist || !castOk {
+		nerr := fmt.Errorf("no public key binding: %v", publicKeyData)
+		log.Warn("get_task parameter fail", "error", nerr)
+
+		types.RenderFailure(ctx, types.ErrCoordinatorParameterInvalidNo, nerr)
+		return ""
+	}
+
+	return publicKey
+}
+
+// PriorityUpstreamManager manages priority upstream mappings with thread safety
+type PriorityUpstreamManager struct {
+	sync.RWMutex
+	data map[string]string
+}
+
+// NewPriorityUpstreamManager creates a new PriorityUpstreamManager
+func NewPriorityUpstreamManager() *PriorityUpstreamManager {
+	return &PriorityUpstreamManager{
+		data: make(map[string]string),
+	}
+}
+
+// Get retrieves the priority upstream for a given key
+func (p *PriorityUpstreamManager) Get(key string) (string, bool) {
+	p.RLock()
+	defer p.RUnlock()
+	value, exists := p.data[key]
+	return value, exists
+}
+
+// Set sets the priority upstream for a given key
+func (p *PriorityUpstreamManager) Set(key, value string) {
+	p.Lock()
+	defer p.Unlock()
+	p.data[key] = value
+}
+
+// Delete removes the priority upstream for a given key
+func (p *PriorityUpstreamManager) Delete(key string) {
+	p.Lock()
+	defer p.Unlock()
+	delete(p.data, key)
+}
+
+// GetTaskController the get prover task api controller
+type GetTaskController struct {
+	proverMgr        *ProverManager
+	clients          Clients
+	priorityUpstream *PriorityUpstreamManager
+
+	workingRnd           *rand.Rand
+	getTaskAccessCounter *prometheus.CounterVec
+}
+
+// NewGetTaskController create a get prover task controller
+func NewGetTaskController(cfg *config.ProxyConfig, clients Clients, proverMgr *ProverManager, priorityMgr *PriorityUpstreamManager, reg prometheus.Registerer) *GetTaskController {
+	// TODO: implement proxy get task controller initialization
+	return &GetTaskController{
+		priorityUpstream: priorityMgr,
+		proverMgr:        proverMgr,
+		clients:          clients,
+	}
+}
+
+func (ptc *GetTaskController) incGetTaskAccessCounter(ctx *gin.Context) error {
+	// TODO: implement proxy get task access counter
+	return nil
+}
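+
+// Illustrative only: the priority map pins a prover (keyed by its public key)
+// to the upstream that last assigned it a task, so follow-up polls try that
+// upstream first ("02ab..." is a hypothetical prover key):
+//
+//	mgr := NewPriorityUpstreamManager()
+//	mgr.Set("02ab...", "coordinator_0")
+//	up, ok := mgr.Get("02ab...") // "coordinator_0", true
+//	mgr.Delete("02ab...")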
+
+// GetTasks gets an assigned chunk/batch task
+func (ptc *GetTaskController) GetTasks(ctx *gin.Context) {
+	var getTaskParameter coordinatorType.GetTaskParameter
+	if err := ctx.ShouldBind(&getTaskParameter); err != nil {
+		nerr := fmt.Errorf("prover task parameter invalid, err:%w", err)
+		types.RenderFailure(ctx, types.ErrCoordinatorParameterInvalidNo, nerr)
+		return
+	}
+
+	publicKey := getSessionData(ctx)
+	if publicKey == "" {
+		return
+	}
+
+	session := ptc.proverMgr.Get(publicKey)
+
+	getTask := func(upStream string, cli Client) (tryNext bool) {
+		resp, err := session.GetTask(ctx, &getTaskParameter, cli, upStream)
+		if err != nil {
+			types.RenderFailure(ctx, types.ErrCoordinatorGetTaskFailure, err)
+			return
+		} else if resp.ErrCode != types.ErrCoordinatorEmptyProofData {
+			if resp.ErrCode != 0 {
+				// simply dispatch the error from the upstream to the prover
+				types.RenderFailure(ctx, resp.ErrCode, fmt.Errorf("%s", resp.ErrMsg))
+				return
+			}
+
+			var task coordinatorType.GetTaskSchema
+			if err = resp.DecodeData(&task); err == nil {
+				task.TaskID = formUpstreamWithTaskName(upStream, task.TaskID)
+				ptc.priorityUpstream.Set(publicKey, upStream)
+				// TODO: log the new id at debug level
+				types.RenderSuccess(ctx, &task)
+			} else {
+				types.RenderFailure(ctx, types.InternalServerError, fmt.Errorf("decode task fail: %v", err))
+			}
+
+			return
+		}
+		tryNext = true
+		return
+	}
+
+	// if the priority upstream is set, try it first, until it returns a task
+	// or an empty-task response
+	priorityUpstream, exist := ptc.priorityUpstream.Get(publicKey)
+	if exist {
+		cli := ptc.clients[priorityUpstream]
+		if cli != nil && !getTask(priorityUpstream, cli) {
+			return
+		} else if cli == nil {
+			// TODO: log error
+		}
+	}
+	ptc.priorityUpstream.Delete(publicKey)
+
+	// Create a slice to hold the keys
+	keys := make([]string, 0, len(ptc.clients))
+	for k := range ptc.clients {
+		keys = append(keys, k)
+	}
+
+	// Shuffle the keys to randomize upstream selection
+	rand.Shuffle(len(keys), func(i, j int) {
+		keys[i], keys[j] = keys[j], keys[i]
+	})
+
+	// Iterate over the shuffled keys
+	for _, n := range keys {
+		if !getTask(n, ptc.clients[n]) {
+			return
+		}
+	}
+
+	// if every upstream failed to return a task, render the empty-task response
+	types.RenderFailure(ctx, types.ErrCoordinatorEmptyProofData, fmt.Errorf("get empty prover task"))
+}
diff --git a/coordinator/internal/controller/proxy/prover_session.go b/coordinator/internal/controller/proxy/prover_session.go
new file mode 100644
index 0000000000..eb41fbe3d4
--- /dev/null
+++ b/coordinator/internal/controller/proxy/prover_session.go
@@ -0,0 +1,235 @@
+package proxy
+
+import (
+	"context"
+	"fmt"
+	"math"
+	"sync"
+	"time"
+
+	ctypes "scroll-tech/common/types"
+
+	"scroll-tech/coordinator/internal/types"
+)
+
+// ProverManager tracks the per-prover sessions, keyed by public key
+type ProverManager struct {
+	sync.RWMutex
+	data map[string]*proverSession
+}
+
+// NewProverManager creates an empty ProverManager
+func NewProverManager() *ProverManager {
+	return &ProverManager{
+		data: make(map[string]*proverSession),
+	}
+}
+
+// Get retrieves the proverSession for a given user key; returns nil if it does not exist yet
+func (m *ProverManager) Get(userKey string) *proverSession {
+	m.RLock()
+	defer m.RUnlock()
+
+	return m.data[userKey]
+}
+
+// GetOrCreate returns the existing session for userKey, creating a new one if needed
+func (m *ProverManager) GetOrCreate(userKey string) *proverSession {
+	m.Lock()
+	defer m.Unlock()
+
+	if ret, ok := m.data[userKey]; ok {
+		return ret
+	}
+
+	ret := &proverSession{
+		proverToken: make(map[string]loginToken),
+	}
+
+	m.data[userKey] = ret
+	return ret
+}
+
+// Set stores the session for userKey
+func (m *ProverManager) Set(userKey string, session *proverSession) {
+	m.Lock()
+	defer m.Unlock()
+
+	m.data[userKey] = session
+}
+
+// loginToken pairs a cached upstream token with the phase it was obtained in
+type loginToken struct {
+	*types.LoginSchema
+	phase uint
+}
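+
+// Illustrative only: the phase counter deduplicates concurrent re-logins to
+// the same upstream. A refresh requested with an older phase than the stored
+// one simply reuses the cached token instead of logging in again:
+//
+//	tok := loginToken{LoginSchema: &types.LoginSchema{Token: "t1"}, phase: 2}
+//	// a maintainLogin call carrying phase 1 (< 2) returns tok.LoginSchema as-is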
+
+// proverSession tracks one prover's login tokens for each upstream coordinator
+type proverSession struct {
+	sync.RWMutex
+	proverToken   map[string]loginToken
+	completionCtx context.Context
+}
+
+// maintainLogin refreshes the cached token of the prover for one upstream,
+// making sure that only a single refresh runs at a time
+func (c *proverSession) maintainLogin(ctx context.Context, cliMgr Client, up string, param *types.LoginParameter, phase uint) (result *types.LoginSchema, nerr error) {
+	c.Lock()
+	curPhase := c.proverToken[up].phase
+	if c.completionCtx != nil {
+		waitctx := c.completionCtx
+		c.Unlock()
+		select {
+		case <-waitctx.Done():
+			return c.maintainLogin(ctx, cliMgr, up, param, phase)
+		case <-ctx.Done():
+			return nil, fmt.Errorf("context done: %w", ctx.Err())
+		}
+	}
+
+	if phase < curPhase {
+		// outdated login phase; return the cached token
+		defer c.Unlock()
+		return c.proverToken[up].LoginSchema, nil
+	}
+
+	// occupy the update slot
+	completeCtx, cf := context.WithCancel(ctx)
+	defer cf()
+	c.completionCtx = completeCtx
+	defer func() {
+		c.Lock()
+		c.completionCtx = nil
+		if result != nil {
+			c.proverToken[up] = loginToken{
+				LoginSchema: result,
+				phase:       curPhase + 1,
+			}
+		}
+		c.Unlock()
+	}()
+	c.Unlock()
+
+	cli := cliMgr.Client(ctx)
+	if cli == nil {
+		return nil, fmt.Errorf("get upstream cli fail")
+	}
+
+	resp, err := cli.ProxyLogin(ctx, param)
+	if err != nil {
+		return nil, fmt.Errorf("proxylogin fail: %w", err)
+	}
+
+	if resp.ErrCode == ctypes.ErrJWTTokenExpired {
+		cliMgr.Reset(cli)
+		cli = cliMgr.Client(ctx)
+		if cli == nil {
+			return nil, fmt.Errorf("get upstream cli fail (secondary try)")
+		}
+
+		// like the SDK, we try one more time if the upstream token is expired
+		resp, err = cli.ProxyLogin(ctx, param)
+		if err != nil {
+			return nil, fmt.Errorf("proxylogin fail: %w", err)
+		}
+	}
+
+	if resp.ErrCode != 0 {
+		return nil, fmt.Errorf("upstream fail: %d (%s)", resp.ErrCode, resp.ErrMsg)
+	}
+
+	var loginResult loginSchema
+	if err := resp.DecodeData(&loginResult); err != nil {
+		return nil, err
+	}
+
+	return &types.LoginSchema{
+		Token: loginResult.Token,
+	}, nil
+}
+
+const expireTolerant = 10 * time.Minute
+
+// ProxyLogin logs the prover in to the given upstream via /v1/proxy_login,
+// refreshing the cached token
+func (c *proverSession) ProxyLogin(ctx context.Context, cli Client, up string, param *types.LoginParameter) error {
+	c.RLock()
+	existedToken := c.proverToken[up].LoginSchema
+	c.RUnlock()
+
+	// Check if we have a valid cached token that hasn't expired
+	if existedToken != nil {
+		// TODO: how to reduce the unnecessary re-login?
+		// timeRemaining := time.Until(existedToken.Time)
+		// if timeRemaining > expireTolerant {
+		//	return nil
+		// }
+	}
+
+	_, err := c.maintainLogin(ctx, cli, up, param, math.MaxUint)
+	return err
+}
+
+// GetTask makes a POST request to /v1/get_task of the upstream with GetTaskParameter
+func (c *proverSession) GetTask(ctx context.Context, param *types.GetTaskParameter, cliMgr Client, up string) (*ctypes.Response, error) {
+	c.RLock()
+	token := c.proverToken[up]
+	c.RUnlock()
+
+	cli := cliMgr.Client(ctx)
+	if cli == nil {
+		return nil, fmt.Errorf("get upstream cli fail")
+	}
+
+	if token.LoginSchema != nil {
+		resp, err := cli.GetTask(ctx, param, token.Token)
+		if err != nil {
+			return nil, err
+		}
+		if resp.ErrCode != ctypes.ErrJWTTokenExpired {
+			return resp, nil
+		}
+	}
+
+	// like the SDK, we try one more time if the upstream token is expired;
+	// get the login parameter from the context
+	loginParam, ok := ctx.Value(LoginParamCache).(*types.LoginParameter)
+	if !ok {
+		return nil, fmt.Errorf("unexpected error: no login parameter in context")
+	}
+
+	newToken, err := c.maintainLogin(ctx, cliMgr, up, loginParam, token.phase)
+	if err != nil {
+		return nil, fmt.Errorf("update prover token fail: %v", err)
+	}
+
+	return cli.GetTask(ctx, param, newToken.Token)
+}
+
+// SubmitProof makes a POST request to /v1/submit_proof of the upstream with SubmitProofParameter
+func (c *proverSession) SubmitProof(ctx context.Context, param *types.SubmitProofParameter, cliMgr Client, up string) (*ctypes.Response, error) {
+	c.RLock()
+	token := c.proverToken[up]
+	c.RUnlock()
+
+	cli := cliMgr.Client(ctx)
+	if cli == nil {
+		return nil, fmt.Errorf("get upstream cli fail")
+	}
+
+	if token.LoginSchema != nil {
+		resp, err := cli.SubmitProof(ctx, param, token.Token)
+		if err != nil {
+			return nil, err
+		}
+		if resp.ErrCode != ctypes.ErrJWTTokenExpired {
+			return resp, nil
+		}
+	}
+
+	// like the SDK, we try one more time if the upstream token is expired;
+	// get the login parameter from the context
+	loginParam, ok := ctx.Value(LoginParamCache).(*types.LoginParameter)
+	if !ok {
+		return nil, fmt.Errorf("unexpected error: no login parameter in context")
+	}
+
+	newToken, err := c.maintainLogin(ctx, cliMgr, up, loginParam, token.phase)
+	if err != nil {
+		return nil, fmt.Errorf("update prover token fail: %v", err)
+	}
+
+	return cli.SubmitProof(ctx, param, newToken.Token)
+}
diff --git a/coordinator/internal/controller/proxy/submit_proof.go b/coordinator/internal/controller/proxy/submit_proof.go
new file mode 100644
index 0000000000..90c582620b
--- /dev/null
+++ b/coordinator/internal/controller/proxy/submit_proof.go
@@ -0,0 +1,82 @@
+package proxy
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+	"github.com/prometheus/client_golang/prometheus"
+
+	"scroll-tech/common/types"
+
+	"scroll-tech/coordinator/internal/config"
+	coordinatorType "scroll-tech/coordinator/internal/types"
+)
+
+// SubmitProofController the submit proof api controller
+type SubmitProofController struct {
+	proverMgr        *ProverManager
+	clients          Clients
+	priorityUpstream *PriorityUpstreamManager
+}
+
+// NewSubmitProofController create the submit proof api controller instance
+func NewSubmitProofController(cfg *config.ProxyConfig, clients Clients, proverMgr *ProverManager, priorityMgr *PriorityUpstreamManager, reg prometheus.Registerer) *SubmitProofController {
+	return &SubmitProofController{
+		proverMgr:        proverMgr,
+		clients:          clients,
+		priorityUpstream: priorityMgr,
+	}
+}
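+
+// Illustrative only: task IDs handed to provers are namespaced with the
+// upstream name so that SubmitProof can route the proof back (see the two
+// helpers below):
+//
+//	id := formUpstreamWithTaskName("coordinator_0", "task-abc") // "coordinator_0:task-abc"
+//	up, raw := upstreamFromTaskName(id)                         // "coordinator_0", "task-abc"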
+
+// upstreamFromTaskName splits an upstream-qualified task ID into the
+// upstream name and the raw task ID
+func upstreamFromTaskName(taskID string) (string, string) {
+	upstream, rest, found := strings.Cut(taskID, ":")
+	if found {
+		return upstream, rest
+	}
+	return "", taskID
+}
+
+// formUpstreamWithTaskName qualifies a task ID with its upstream name
+func formUpstreamWithTaskName(upstream string, taskID string) string {
+	return fmt.Sprintf("%s:%s", upstream, taskID)
+}
+
+// SubmitProof prover submit the proof to coordinator
+func (spc *SubmitProofController) SubmitProof(ctx *gin.Context) {
+	var submitParameter coordinatorType.SubmitProofParameter
+	if err := ctx.ShouldBind(&submitParameter); err != nil {
+		nerr := fmt.Errorf("prover submitProof parameter invalid, err:%w", err)
+		types.RenderFailure(ctx, types.ErrCoordinatorParameterInvalidNo, nerr)
+		return
+	}
+
+	publicKey := getSessionData(ctx)
+	if publicKey == "" {
+		return
+	}
+
+	session := spc.proverMgr.Get(publicKey)
+	upstream, realTaskID := upstreamFromTaskName(submitParameter.TaskID)
+	cli, existed := spc.clients[upstream]
+	if !existed {
+		// TODO: log error
+		nerr := fmt.Errorf("invalid upstream name (%s) from taskID %s", upstream, submitParameter.TaskID)
+		types.RenderFailure(ctx, types.ErrCoordinatorParameterInvalidNo, nerr)
+		return
+	}
+	submitParameter.TaskID = realTaskID
+
+	resp, err := session.SubmitProof(ctx, &submitParameter, cli, upstream)
+	if err != nil {
+		types.RenderFailure(ctx, types.ErrCoordinatorGetTaskFailure, err)
+		return
+	}
+	if resp.ErrCode != 0 {
+		// simply dispatch the error from the upstream to the prover
+		types.RenderFailure(ctx, resp.ErrCode, fmt.Errorf("%s", resp.ErrMsg))
+		return
+	}
+	// clear the prover's pinned upstream (the map is keyed by public key)
+	spc.priorityUpstream.Delete(publicKey)
+	types.RenderSuccess(ctx, resp.Data)
+}
diff --git a/coordinator/internal/logic/auth/login.go b/coordinator/internal/logic/auth/login.go
index 28cb3b8a18..f691eb0e85 100644
--- a/coordinator/internal/logic/auth/login.go
+++ b/coordinator/internal/logic/auth/login.go
@@ -1,6 +1,7 @@
 package auth
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"strings"
@@ -19,45 +20,72 @@ import (
 
 // LoginLogic the auth logic
 type LoginLogic struct {
-	cfg          *config.Config
-	challengeOrm *orm.Challenge
+	cfg          *config.VerifierConfig
+	deduplicator ChallengeDeduplicator
 
 	openVmVks                map[string]struct{}
 	proverVersionHardForkMap map[string]string
 }
 
+// ChallengeDeduplicator ensures a login challenge string is used only once
+type ChallengeDeduplicator interface {
+	InsertChallenge(ctx context.Context, challengeString string) error
+}
+
+// SimpleDeduplicator is a no-op deduplicator
+type SimpleDeduplicator struct {
+}
+
+// InsertChallenge accepts any challenge string without deduplication
+func (s *SimpleDeduplicator) InsertChallenge(ctx context.Context, challengeString string) error {
+	return nil
+}
+
+// NewLoginLogicWithSimpleDeduplicator creates a LoginLogic that does not use the db to deduplicate challenges
+func NewLoginLogicWithSimpleDeduplicator(vcfg *config.VerifierConfig, vf *verifier.Verifier) *LoginLogic {
+	return newLoginLogic(&SimpleDeduplicator{}, vcfg, vf)
+}
+
 // NewLoginLogic new a LoginLogic
-func NewLoginLogic(db *gorm.DB, cfg *config.Config, vf *verifier.Verifier) *LoginLogic {
+func NewLoginLogic(db *gorm.DB, vcfg *config.VerifierConfig, vf *verifier.Verifier) *LoginLogic {
+	return newLoginLogic(orm.NewChallenge(db), vcfg, vf)
+}
+
+func newLoginLogic(deduplicator ChallengeDeduplicator, vcfg *config.VerifierConfig, vf *verifier.Verifier) *LoginLogic {
 	proverVersionHardForkMap := make(map[string]string)
-	for _, cfg := range cfg.ProverManager.Verifier.Verifiers {
+	for _, cfg := range vcfg.Verifiers {
 		proverVersionHardForkMap[cfg.ForkName] = cfg.MinProverVersion
 	}
 
 	return &LoginLogic{
-		cfg:          cfg,
+		cfg:          vcfg,
 		openVmVks:    vf.OpenVMVkMap,
-		challengeOrm: orm.NewChallenge(db),
+		deduplicator: deduplicator,
 
 		proverVersionHardForkMap: proverVersionHardForkMap,
 	}
 }
 
-// InsertChallengeString insert and check the challenge string is existed
-func (l *LoginLogic) InsertChallengeString(ctx *gin.Context, challenge string) error {
-	return l.challengeOrm.InsertChallenge(ctx.Copy(), challenge)
-}
-
-func (l *LoginLogic) Check(login *types.LoginParameter) error {
+// VerifyMsg verifies the signature and completeness of the login message
+func VerifyMsg(login *types.LoginParameter) error {
 	verify, err := login.Verify()
 	if err != nil || !verify {
 		log.Error("auth message verify failure", "prover_name", login.Message.ProverName,
 			"prover_version", login.Message.ProverVersion, "message", login.Message)
 		return errors.New("auth message verify failure")
 	}
+	return nil
+}
 
-	if !version.CheckScrollRepoVersion(login.Message.ProverVersion, l.cfg.ProverManager.Verifier.MinProverVersion) {
-		return fmt.Errorf("incompatible prover version. please upgrade your prover, minimum allowed version: %s, actual version: %s", l.cfg.ProverManager.Verifier.MinProverVersion, login.Message.ProverVersion)
+// InsertChallengeString inserts the challenge string and checks whether it already exists
+func (l *LoginLogic) InsertChallengeString(ctx *gin.Context, challenge string) error {
+	return l.deduplicator.InsertChallenge(ctx.Copy(), challenge)
+}
+
+// CompatibilityCheck checks whether the logging-in client is compatible with the coordinator's settings
+func (l *LoginLogic) CompatibilityCheck(login *types.LoginParameter) error {
+	if !version.CheckScrollRepoVersion(login.Message.ProverVersion, l.cfg.MinProverVersion) {
+		return fmt.Errorf("incompatible prover version. please upgrade your prover, minimum allowed version: %s, actual version: %s", l.cfg.MinProverVersion, login.Message.ProverVersion)
 	}
 
 	vks := make(map[string]struct{})
@@ -65,27 +93,32 @@ func (l *LoginLogic) Check(login *types.LoginParameter) error {
 		vks[vk] = struct{}{}
 	}
 
-	for _, vk := range login.Message.VKs {
-		if _, ok := vks[vk]; !ok {
-			log.Error("vk inconsistency", "prover vk", vk, "prover name", login.Message.ProverName,
-				"prover_version", login.Message.ProverVersion, "message", login.Message)
-			if !version.CheckScrollProverVersion(login.Message.ProverVersion) {
-				return fmt.Errorf("incompatible prover version. please upgrade your prover, expect version: %s, actual version: %s",
-					version.Version, login.Message.ProverVersion)
+	// the new coordinator / proxy does not check vks at login; this code exists
+	// only for backward compatibility
+	if len(vks) != 0 {
+		for _, vk := range login.Message.VKs {
+			if _, ok := vks[vk]; !ok {
+				log.Error("vk inconsistency", "prover vk", vk, "prover name", login.Message.ProverName,
+					"prover_version", login.Message.ProverVersion, "message", login.Message)
+				if !version.CheckScrollProverVersion(login.Message.ProverVersion) {
+					return fmt.Errorf("incompatible prover version. please upgrade your prover, expect version: %s, actual version: %s",
+						version.Version, login.Message.ProverVersion)
+				}
+				// if the prover reports the same prover version
+				return errors.New("incompatible vk. please check your params files or config files")
+			}
 		}
-			// if the prover reports a same prover version
-			return errors.New("incompatible vk. please check your params files or config files")
-		}
 	}
 
-	if login.Message.ProverProviderType != types.ProverProviderTypeInternal && login.Message.ProverProviderType != types.ProverProviderTypeExternal {
+	switch login.Message.ProverProviderType {
+	case types.ProverProviderTypeInternal:
+	case types.ProverProviderTypeExternal:
+	case types.ProverProviderTypeProxy:
+	case types.ProverProviderTypeUndefined:
 		// for backward compatibility, set ProverProviderType as internal
-		if login.Message.ProverProviderType == types.ProverProviderTypeUndefined {
-			login.Message.ProverProviderType = types.ProverProviderTypeInternal
-		} else {
-			log.Error("invalid prover_provider_type", "value", login.Message.ProverProviderType, "prover name", login.Message.ProverName, "prover version", login.Message.ProverVersion)
-			return errors.New("invalid prover provider type.")
-		}
+		login.Message.ProverProviderType = types.ProverProviderTypeInternal
+	default:
+		log.Error("invalid prover_provider_type", "value", login.Message.ProverProviderType, "prover name", login.Message.ProverName, "prover version", login.Message.ProverVersion)
+		return errors.New("invalid prover provider type")
 	}
 
 	return nil
diff --git a/coordinator/internal/middleware/challenge_jwt.go b/coordinator/internal/middleware/challenge_jwt.go
index 6ee8254b07..99a58cc8db 100644
--- a/coordinator/internal/middleware/challenge_jwt.go
+++ b/coordinator/internal/middleware/challenge_jwt.go
@@ -14,7 +14,7 @@ import (
 )
 
 // ChallengeMiddleware jwt challenge middleware
-func ChallengeMiddleware(conf *config.Config) *jwt.GinJWTMiddleware {
+func ChallengeMiddleware(auth *config.Auth) *jwt.GinJWTMiddleware {
 	jwtMiddleware, err := jwt.New(&jwt.GinJWTMiddleware{
 		Authenticator: func(c *gin.Context) (interface{}, error) {
 			return nil, nil
@@ -30,8 +30,8 @@ func ChallengeMiddleware(conf *config.Config) *jwt.GinJWTMiddleware {
 			}
 		},
 		Unauthorized: unauthorized,
-		Key:          []byte(conf.Auth.Secret),
-		Timeout:      time.Second * time.Duration(conf.Auth.ChallengeExpireDurationSec),
+		Key:          []byte(auth.Secret),
+		Timeout:      time.Second * time.Duration(auth.ChallengeExpireDurationSec),
 		TokenLookup:   "header: Authorization, query: token, cookie: jwt",
 		TokenHeadName: "Bearer",
 		TimeFunc:      time.Now,
diff --git a/coordinator/internal/middleware/login_jwt.go b/coordinator/internal/middleware/login_jwt.go
index b04810b0b7..66d9702ac8 100644
--- a/coordinator/internal/middleware/login_jwt.go
+++ b/coordinator/internal/middleware/login_jwt.go
@@ -4,22 +4,60 @@ import (
 	"time"
 
 	jwt "github.com/appleboy/gin-jwt/v2"
+	"github.com/gin-gonic/gin"
 	"github.com/scroll-tech/go-ethereum/log"
 
 	"scroll-tech/coordinator/internal/config"
 	"scroll-tech/coordinator/internal/controller/api"
+	"scroll-tech/coordinator/internal/controller/proxy"
 	"scroll-tech/coordinator/internal/types"
 )
 
+// nonIdentityAuthorizator authorizes any request that carries a non-nil identity
+func nonIdentityAuthorizator(data interface{}, _ *gin.Context) bool {
+	return data != nil
+}
+
 // LoginMiddleware jwt auth middleware
-func LoginMiddleware(conf *config.Config) *jwt.GinJWTMiddleware {
+func LoginMiddleware(auth *config.Auth) *jwt.GinJWTMiddleware {
 	jwtMiddleware, err := jwt.New(&jwt.GinJWTMiddleware{
 		PayloadFunc:     api.Auth.PayloadFunc,
 		IdentityHandler: api.Auth.IdentityHandler,
 		IdentityKey:     types.PublicKey,
-		Key:             []byte(conf.Auth.Secret),
-		Timeout:         time.Second * time.Duration(conf.Auth.LoginExpireDurationSec),
+		Key:             []byte(auth.Secret),
+		Timeout:         time.Second * time.Duration(auth.LoginExpireDurationSec),
 		Authenticator:   api.Auth.Login,
+		Authorizator:    nonIdentityAuthorizator,
 		Unauthorized:    unauthorized,
 		TokenLookup:     "header: Authorization, query: token, cookie: jwt",
 		TokenHeadName:   "Bearer",
 		TimeFunc:        time.Now,
 		LoginResponse:   loginResponse,
 	})
 
 	if err != nil {
 		log.Crit("new jwt middleware panic", "error", err)
 	}
 
 	if errInit := jwtMiddleware.MiddlewareInit(); errInit != nil {
 		log.Crit("init jwt middleware panic", "error", errInit)
 	}
 
 	return jwtMiddleware
 }
+
+// ProxyLoginMiddleware jwt auth middleware for proxy login
+func ProxyLoginMiddleware(auth *config.Auth) *jwt.GinJWTMiddleware {
+	jwtMiddleware, err := jwt.New(&jwt.GinJWTMiddleware{
+		PayloadFunc:     proxy.Auth.PayloadFunc,
+		IdentityHandler: proxy.Auth.IdentityHandler,
+		IdentityKey:     types.PublicKey,
+		Key:             []byte(auth.Secret),
+		Timeout:         time.Second * time.Duration(auth.LoginExpireDurationSec),
+		Authenticator:   proxy.Auth.Login,
+		Authorizator:    nonIdentityAuthorizator,
+		Unauthorized:    unauthorized,
+		TokenLookup:     "header: Authorization, query: token, cookie: jwt",
+		TokenHeadName:   "Bearer",
+		TimeFunc:        time.Now,
+		LoginResponse:   loginResponse,
+	})
+
+	if err != nil {
+		log.Crit("new jwt middleware panic", "error", err)
+	}
+
+	if errInit := jwtMiddleware.MiddlewareInit(); errInit != nil {
+		log.Crit("init jwt middleware panic", "error", errInit)
+	}
+
+	return jwtMiddleware
+}
diff --git a/coordinator/internal/middleware/proxy_bearer.go b/coordinator/internal/middleware/proxy_bearer.go
new file mode 100644
index 0000000000..c870d7c164
--- /dev/null
+++ b/coordinator/internal/middleware/proxy_bearer.go
@@ -0,0 +1 @@
+package middleware
diff --git a/coordinator/internal/route/route.go b/coordinator/internal/route/route.go
index 9e9eef076e..5d9a7c65a4 100644
--- a/coordinator/internal/route/route.go
+++ b/coordinator/internal/route/route.go
@@ -8,6 +8,7 @@ import (
 
 	"scroll-tech/coordinator/internal/config"
 	"scroll-tech/coordinator/internal/controller/api"
+	"scroll-tech/coordinator/internal/controller/proxy"
 	"scroll-tech/coordinator/internal/middleware"
 )
 
@@ -25,16 +26,45 @@ func Route(router *gin.Engine, cfg *config.Config, reg prometheus.Registerer) {
 func v1(router *gin.RouterGroup, conf *config.Config) {
 	r := router.Group("/v1")
 
-	challengeMiddleware := middleware.ChallengeMiddleware(conf)
+	challengeMiddleware := middleware.ChallengeMiddleware(conf.Auth)
 	r.GET("/challenge", challengeMiddleware.LoginHandler)
 
-	loginMiddleware := middleware.LoginMiddleware(conf)
+	loginMiddleware := middleware.LoginMiddleware(conf.Auth)
 	r.POST("/login", challengeMiddleware.MiddlewareFunc(), loginMiddleware.LoginHandler)
 
 	// need jwt token api
 	r.Use(loginMiddleware.MiddlewareFunc())
 	{
+		r.POST("/proxy_login", loginMiddleware.LoginHandler)
 		r.POST("/get_task", api.GetTask.GetTasks)
 		r.POST("/submit_proof", api.SubmitProof.SubmitProof)
 	}
 }
+
+// ProxyRoute registers the routes for the coordinator proxy
+func ProxyRoute(router *gin.Engine, cfg *config.ProxyConfig, reg prometheus.Registerer) {
+	router.Use(gin.Recovery())
+
+	observability.Use(router, "coordinator", reg)
+
+	r := router.Group("coordinator")
+
+	v1Proxy(r, cfg)
+}
+
+func v1Proxy(router *gin.RouterGroup, conf *config.ProxyConfig) {
+	r := router.Group("/v1")
+
+	challengeMiddleware := middleware.ChallengeMiddleware(conf.ProxyManager.Auth)
+	r.GET("/challenge", challengeMiddleware.LoginHandler)
+
+	loginMiddleware := middleware.ProxyLoginMiddleware(conf.ProxyManager.Auth)
+	r.POST("/login", challengeMiddleware.MiddlewareFunc(), loginMiddleware.LoginHandler)
+
+	// need jwt token api
+	r.Use(loginMiddleware.MiddlewareFunc())
+	{
+		r.POST("/get_task", proxy.GetTask.GetTasks)
+		r.POST("/submit_proof", proxy.SubmitProof.SubmitProof)
+	}
+}
diff --git a/coordinator/internal/types/prover.go b/coordinator/internal/types/prover.go
index 048fac00a2..4254c673d5 100644
--- a/coordinator/internal/types/prover.go
+++ b/coordinator/internal/types/prover.go
@@ -64,6 +64,8 @@ func (r ProverProviderType) String() string {
 		return "prover provider type internal"
 	case ProverProviderTypeExternal:
 		return "prover provider type external"
+	case ProverProviderTypeProxy:
+		return "prover provider type proxy"
 	default:
 		return fmt.Sprintf("prover provider type: %d", r)
 	}
@@ -76,4 +78,6 @@ const (
 	ProverProviderTypeInternal
 	// ProverProviderTypeExternal is an external prover provider type
 	ProverProviderTypeExternal
+	// ProverProviderTypeProxy is a proxy prover provider type
+	ProverProviderTypeProxy ProverProviderType = 3
 )
diff --git a/coordinator/internal/types/response_test.go b/coordinator/internal/types/response_test.go
new file mode 100644
index 0000000000..6508d870a6
--- /dev/null
+++ b/coordinator/internal/types/response_test.go
@@ -0,0 +1,48 @@
+package types
+
+import (
+	"encoding/json"
+	"reflect"
+	"testing"
+
+	"scroll-tech/common/types"
+)
+
+func TestResponseDecodeData_GetTaskSchema(t *testing.T) {
+	// Arrange: build a dummy payload and wrap it in a Response
+	in := GetTaskSchema{
+		UUID:         "uuid-123",
+		TaskID:       "task-abc",
+		TaskType:     1,
+		UseSnark:     true,
+		TaskData:     "dummy-data",
+		HardForkName: "cancun",
+	}
+
+	resp := types.Response{
+		ErrCode: 0,
+		ErrMsg:  "",
+		Data:    in,
+	}
+
+	// Act: JSON round-trip the Response to simulate real HTTP encoding/decoding
+	b, err := json.Marshal(resp)
+	if err != nil {
+		t.Fatalf("marshal response: %v", err)
+	}
+
+	var decoded types.Response
+	if err := json.Unmarshal(b, &decoded); err != nil {
+		t.Fatalf("unmarshal response: %v", err)
+	}
+
+	var out GetTaskSchema
+	if err := decoded.DecodeData(&out); err != nil {
+		t.Fatalf("DecodeData error: %v", err)
+	}
+
+	// Assert: structs match after decode
+	if !reflect.DeepEqual(in, out) {
+		t.Fatalf("decoded struct mismatch:\nwant: %+v\n got: %+v", in, out)
+	}
+}
diff --git a/coordinator/test/api_test.go b/coordinator/test/api_test.go
index 053f6b715e..f435dd16b6 100644
--- a/coordinator/test/api_test.go
+++ b/coordinator/test/api_test.go
@@ -51,6 +51,8 @@ var (
 	chunk        *encoding.Chunk
 	batch        *encoding.Batch
 	tokenTimeout int
+
+	envSet bool
 )
 
 func TestMain(m *testing.M) {
@@ -67,6 +69,25 @@ func randomURL() string {
 	return fmt.Sprintf("localhost:%d", 10000+2000+id.Int64())
 }
 
+// randomURLBatch generates a batch of random localhost URLs with distinct ports, similar to randomURL
+func randomURLBatch(n int) []string {
+	if n <= 0 {
+		return nil
+	}
+	urls := make([]string, 0, n)
+	used := make(map[int64]struct{}, n)
+	for len(urls) < n {
+		id, _ := rand.Int(rand.Reader, big.NewInt(2000-1))
+		port := 10000 + 2000 + id.Int64()
+		if _, ok := used[port]; ok {
+			continue
+		}
+		used[port] = struct{}{}
+		urls = append(urls, fmt.Sprintf("localhost:%d", port))
+	}
+	return urls
+}
+
 func setupCoordinator(t *testing.T, proversPerSession uint8, coordinatorURL string) (*cron.Collector, *http.Server) {
 	var err error
 	db, err = testApps.GetGormDBClient()
@@ -130,6 +151,11 @@ func setupCoordinator(t *testing.T, proversPerSession uint8, coordinatorURL stri
 }
 
 func setEnv(t *testing.T) {
+	if envSet {
+		t.Log("setEnv is re-entered")
+		return
+	}
+
 	var err error
 
 	version.Version = "v4.4.89"
@@ -169,6 +195,7 @@ func setEnv(t *testing.T) {
 	assert.NoError(t, err)
 
 	batch = &encoding.Batch{Chunks: []*encoding.Chunk{chunk}}
+	envSet = true
 }
 
 func TestApis(t *testing.T) {
diff --git a/coordinator/test/mock_prover.go b/coordinator/test/mock_prover.go
index 0076199b33..958c230547 100644
--- a/coordinator/test/mock_prover.go
+++ b/coordinator/test/mock_prover.go
@@ -191,7 +191,7 @@ func (r *mockProver) tryGetProverTask(t *testing.T, proofType message.ProofType)
 	resp, err := client.R().
 		SetHeader("Content-Type", "application/json").
 		SetHeader("Authorization", fmt.Sprintf("Bearer %s", token)).
-		SetBody(map[string]interface{}{"prover_height": 100, "task_type": int(proofType), "universal": true}).
+		SetBody(map[string]interface{}{"prover_height": 100, "task_types": []int{int(proofType)}, "universal": true}).
 		SetResult(&result).
 		Post("http://" + r.coordinatorURL + "/coordinator/v1/get_task")
 	assert.NoError(t, err)
diff --git a/coordinator/test/proxy_test.go b/coordinator/test/proxy_test.go
new file mode 100644
index 0000000000..b8a09afbac
--- /dev/null
+++ b/coordinator/test/proxy_test.go
@@ -0,0 +1,273 @@
+package test
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net/http"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/scroll-tech/da-codec/encoding"
+	"github.com/stretchr/testify/assert"
+
+	"scroll-tech/common/types"
+	"scroll-tech/common/types/message"
+	"scroll-tech/common/version"
+
+	"scroll-tech/coordinator/internal/config"
+	"scroll-tech/coordinator/internal/controller/proxy"
+	"scroll-tech/coordinator/internal/route"
+)
+
+func testProxyClientCfg() *config.ProxyClient {
+	return &config.ProxyClient{
+		Secret:       "test-secret-key",
+		ProxyName:    "test-proxy",
+		ProxyVersion: version.Version,
+	}
+}
+
+func testProxyUpStreamCfg(coordinatorURL string) *config.UpStream {
+	return &config.UpStream{
+		BaseUrl:              fmt.Sprintf("http://%s", coordinatorURL),
+		RetryWaitTime:        3,
+		ConnectionTimeoutSec: 30,
+	}
+}
+
+func testProxyClient(t *testing.T) {
+	// Setup coordinator and http server.
+	coordinatorURL := randomURL()
+	proofCollector, httpHandler := setupCoordinator(t, 1, coordinatorURL)
+	defer func() {
+		proofCollector.Stop()
+		assert.NoError(t, httpHandler.Shutdown(context.Background()))
+	}()
+
+	cliCfg := testProxyClientCfg()
+	upCfg := testProxyUpStreamCfg(coordinatorURL)
+
+	clientManager, err := proxy.NewClientManager("test_coordinator", cliCfg, upCfg)
+	assert.NoError(t, err)
+	assert.NotNil(t, clientManager)
+
+	// Create a context with timeout
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	// Test the Client method: it blocks until the first login succeeds (or the
+	// context times out), so a non-nil client holding a token indicates the
+	// proxy authenticated against the coordinator
+	client := clientManager.Client(ctx)
+	assert.NotNil(t, client)
+	assert.NotEmpty(t, client.Token())
+	t.Logf("Client token: %s (%v)", client.Token(), client)
+}
+
+var (
+	proxyConf *config.ProxyConfig
+)
+
+func setupProxy(t *testing.T, proxyURL string, coordinatorURL []string) *http.Server {
+	coordinators := make(map[string]*config.UpStream)
+	for i, n := range coordinatorURL {
+		coordinators[fmt.Sprintf("coordinator_%d", i)] = testProxyUpStreamCfg(n)
+	}
+
+	tokenTimeout = 60
+	proxyConf = &config.ProxyConfig{
+		ProxyName: "test_proxy",
+		ProxyManager: &config.ProxyManager{
+			Verifier: &config.VerifierConfig{
+				MinProverVersion: "v4.4.89",
+				Verifiers: []config.AssetConfig{{
+					AssetsPath: "",
+					ForkName:   "euclidV2",
+				}},
+			},
+			Client: testProxyClientCfg(),
+			Auth: &config.Auth{
+				Secret:                     "proxy",
+				ChallengeExpireDurationSec: tokenTimeout,
+				LoginExpireDurationSec:     tokenTimeout,
+			},
+		},
+		Coordinators: coordinators,
+	}
+
+	router := gin.New()
+	proxy.InitController(proxyConf, nil)
+	route.ProxyRoute(router, proxyConf, nil)
+	srv := &http.Server{
+		Addr:    proxyURL,
+		Handler: router,
+	}
+	go func() {
+		runErr := srv.ListenAndServe()
+		if runErr != nil && !errors.Is(runErr, http.ErrServerClosed) {
+			assert.NoError(t, runErr)
+		}
+	}()
+	time.Sleep(time.Second * 2)
+
+	return srv
+}
+
+func testProxyHandshake(t *testing.T) {
+	// Setup proxy http server.
+	proxyURL := randomURL()
+	proxyHttpHandler := setupProxy(t, proxyURL, []string{})
+	defer func() {
+		assert.NoError(t, proxyHttpHandler.Shutdown(context.Background()))
+	}()
+
+	chunkProver := newMockProver(t, "prover_chunk_test", proxyURL, message.ProofTypeChunk, version.Version)
+	assert.True(t, chunkProver.healthCheckSuccess(t))
+}
+
+func testProxyGetTask(t *testing.T) {
+	// Setup coordinator and http server.
+	urls := randomURLBatch(2)
+	coordinatorURL := urls[0]
+	collector, httpHandler := setupCoordinator(t, 3, coordinatorURL)
+	defer func() {
+		collector.Stop()
+		assert.NoError(t, httpHandler.Shutdown(context.Background()))
+	}()
+
+	proxyURL := urls[1]
+	proxyHttpHandler := setupProxy(t, proxyURL, []string{coordinatorURL})
+	defer func() {
+		assert.NoError(t, proxyHttpHandler.Shutdown(context.Background()))
+	}()
+
+	chunkProver := newMockProver(t, "prover_chunk_test", proxyURL, message.ProofTypeChunk, version.Version)
+	code, msg := chunkProver.tryGetProverTask(t, message.ProofTypeChunk)
+	assert.Equal(t, int(types.ErrCoordinatorEmptyProofData), code)
+
+	err := l2BlockOrm.InsertL2Blocks(context.Background(), []*encoding.Block{block1, block2})
+	assert.NoError(t, err)
+	dbChunk, err := chunkOrm.InsertChunk(context.Background(), chunk)
+	assert.NoError(t, err)
+	err = l2BlockOrm.UpdateChunkHashInRange(context.Background(), 0, 100, dbChunk.Hash)
+	assert.NoError(t, err)
+
+	task, code, msg := chunkProver.getProverTask(t, message.ProofTypeChunk)
+	assert.Empty(t, code)
+	if code == 0 {
+		t.Log("get task id", task.TaskID)
+	} else {
+		t.Log("get task error msg", msg)
+	}
+}
+
+func testProxyProof(t *testing.T) {
+	urls := randomURLBatch(3)
+	coordinatorURL0 := urls[0]
+	collector0, httpHandler0 := setupCoordinator(t, 3, coordinatorURL0)
+	defer func() {
+		collector0.Stop()
+		httpHandler0.Shutdown(context.Background())
+	}()
+	coordinatorURL1 := urls[1]
+	collector1, httpHandler1 := setupCoordinator(t, 3, coordinatorURL1)
+	defer func() {
+		collector1.Stop()
+		httpHandler1.Shutdown(context.Background())
+	}()
+	coordinators := map[string]*http.Server{
+		"coordinator_0": httpHandler0,
+		"coordinator_1": httpHandler1,
+	}
+
+	proxyURL := urls[2]
+	proxyHttpHandler := setupProxy(t, proxyURL, []string{coordinatorURL0, coordinatorURL1})
+	defer func() {
+		assert.NoError(t, proxyHttpHandler.Shutdown(context.Background()))
+	}()
+
+	err := l2BlockOrm.InsertL2Blocks(context.Background(), []*encoding.Block{block1, block2})
+	assert.NoError(t, err)
+	dbChunk, err := chunkOrm.InsertChunk(context.Background(), chunk)
+	assert.NoError(t, err)
+	err = l2BlockOrm.UpdateChunkHashInRange(context.Background(), 0, 100, dbChunk.Hash)
+	assert.NoError(t, err)
+
+	chunkProver := newMockProver(t, "prover_chunk_test", proxyURL, message.ProofTypeChunk, version.Version)
+	task, code, msg := chunkProver.getProverTask(t, message.ProofTypeChunk)
+	assert.Empty(t, code)
+	if code == 0 {
+		t.Log("get task", task)
+		upstream, _, _ := strings.Cut(task.TaskID, ":")
+		// shut down the coordinator that did not dispatch the task; if the proxy
+		// routed the submission to the wrong upstream, the submit would hit the
+		// closed coordinator and fail
+		for n, srv := range coordinators {
+			if n != upstream {
+				t.Log("close coordinator", n)
+				assert.NoError(t, srv.Shutdown(context.Background()))
+			}
+		}
+		expectProofStatus := verifiedSuccess
+		chunkProver.submitProof(t, task, expectProofStatus, types.Success)
+	} else {
+		t.Log("get task error msg", msg)
+	}
+
+	// verify proof status
+	var (
+		tick     = time.Tick(1500 * time.Millisecond)
+		tickStop = time.Tick(time.Minute)
+	)
+
+	var (
+		chunkProofStatus    types.ProvingStatus
+		chunkActiveAttempts int16
+		chunkMaxAttempts    int16
+	)
+
+	for {
+		select {
+		case <-tick:
+			chunkProofStatus, err = chunkOrm.GetProvingStatusByHash(context.Background(), dbChunk.Hash)
+			assert.NoError(t, err)
+			if chunkProofStatus == types.ProvingTaskVerified {
+				return
+			}
+
+			chunkActiveAttempts, chunkMaxAttempts, err = chunkOrm.GetAttemptsByHash(context.Background(), dbChunk.Hash)
+			assert.NoError(t, err)
+			assert.Equal(t, 1, int(chunkMaxAttempts))
+			assert.Equal(t, 0, int(chunkActiveAttempts))
+
+		case <-tickStop:
+			t.Error("failed to check proof status", "chunkProofStatus", chunkProofStatus.String())
+			return
+		}
+	}
+}
+
+func TestProxyClient(t *testing.T) {
+	// Set up the test environment.
+	setEnv(t)
+
+	t.Run("TestProxyClient", testProxyClient)
+	t.Run("TestProxyHandshake", testProxyHandshake)
+	t.Run("TestProxyGetTask", testProxyGetTask)
+	t.Run("TestProxyValidProof", testProxyProof)
+}