Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ NODE_HOST = "0.0.0.0"

SSL_CERT_FILE = /var/lib/gozargah-node/certs/ssl_cert.pem
SSL_KEY_FILE = /var/lib/gozargah-node/certs/ssl_key.pem
SSL_CLIENT_CERT_FILE = /var/lib/gozargah-node/certs/ssl_client_cert.pem

# api key must be a valid uuid (you can use any version you want)
API_KEY = xxxxxxxx-yyyy-zzzz-mmmm-aaaaaaaaaaa

### can be rest or grpc
# SERVICE_PROTOCOL = grpc
# MAX_LOG_PER_REQUEST = 1000

### for developers
# DEBUG = false
Expand Down
360 changes: 170 additions & 190 deletions common/service.pb.go

Large diffs are not rendered by default.

2 changes: 0 additions & 2 deletions common/service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ message BaseInfoResponse {
bool started = 1;
string core_version = 2;
string node_version = 3;
string session_id = 4;
string extra = 5;
}

enum BackendType {
Expand Down
27 changes: 19 additions & 8 deletions config/config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package config

import (
"github.com/google/uuid"
"log"
"os"
"regexp"
Expand All @@ -20,10 +21,12 @@ func init() {
XrayAssetsPath = GetEnv("XRAY_ASSETS_PATH", "/usr/local/share/xray")
SslCertFile = GetEnv("SSL_CERT_FILE", "/var/lib/gozargah-node/certs/ssl_cert.pem")
SslKeyFile = GetEnv("SSL_KEY_FILE", "/var/lib/gozargah-node/certs/ssl_key.pem")
SslClientCertFile = GetEnv("SSL_CLIENT_CERT_FILE", "/var/lib/gozargah-node/certs/ssl_client_cert.pem")
ApiKey, err = GetEnvAsUUID("API_KEY")
if err != nil {
log.Printf("[Error] Faild to load API Key, error: %v", err)
}
GeneratedConfigPath = GetEnv("GENERATED_CONFIG_PATH", "/var/lib/gozargah-node/generated/")
ServiceProtocol = GetEnv("SERVICE_PROTOCOL", "grpc")
MaxLogPerRequest = GetEnvAsInt("MAX_LOG_PER_REQUEST", 1000)
Debug = GetEnvAsBool("DEBUG", false)
nodeHostStr := GetEnv("NODE_HOST", "0.0.0.0")

Expand All @@ -42,17 +45,16 @@ func init() {
}

// Warning: only use in tests
func SetEnv(port, maxLogPerRequest int, host, xrayExecutablePath, xrayAssetsPath, sslCertFile, sslKeyFile, sslClientCertFile,
serviceProtocol, generatedConfigPath string, debug bool) {
func SetEnv(port int, host, xrayExecutablePath, xrayAssetsPath, sslCertFile, sslKeyFile,
serviceProtocol, generatedConfigPath string, apiKey uuid.UUID, debug bool) {
ServicePort = port
NodeHost = host
XrayExecutablePath = xrayExecutablePath
XrayAssetsPath = xrayAssetsPath
SslCertFile = sslCertFile
SslKeyFile = sslKeyFile
SslClientCertFile = sslClientCertFile
ApiKey = apiKey
ServiceProtocol = serviceProtocol
MaxLogPerRequest = maxLogPerRequest
GeneratedConfigPath = generatedConfigPath
Debug = debug
}
Expand Down Expand Up @@ -81,16 +83,25 @@ func GetEnvAsInt(name string, defaultVal int) int {
return defaultVal
}

func GetEnvAsUUID(name string) (uuid.UUID, error) {
valStr := GetEnv(name, "")

val, err := uuid.Parse(valStr)
if err != nil {
return uuid.Nil, err
}
return val, nil
}

var (
ServicePort int
NodeHost string
XrayExecutablePath string
XrayAssetsPath string
SslCertFile string
SslKeyFile string
SslClientCertFile string
ApiKey uuid.UUID
ServiceProtocol string
MaxLogPerRequest int
Debug bool
GeneratedConfigPath string
)
16 changes: 5 additions & 11 deletions controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type Service interface {

type Controller struct {
backend backend.Backend
sessionID uuid.UUID
ApiKey uuid.UUID
apiPort int
clientIP string
lastRequest time.Time
Expand All @@ -36,21 +36,20 @@ type Controller struct {
func (c *Controller) Init() {
c.mu.Lock()
defer c.mu.Unlock()
c.sessionID = uuid.Nil
c.ApiKey = config.ApiKey
c.apiPort = tools.FindFreePort()
_, c.cancelFunc = context.WithCancel(context.Background())
}

func (c *Controller) GetSessionID() uuid.UUID {
func (c *Controller) GetApiKey() uuid.UUID {
c.mu.RLock()
defer c.mu.RUnlock()
return c.sessionID
return c.ApiKey
}

func (c *Controller) Connect(ip string, keepAlive uint64) {
c.mu.Lock()
defer c.mu.Unlock()
c.sessionID = uuid.New()
c.lastRequest = time.Now()
c.clientIP = ip

Expand All @@ -76,7 +75,6 @@ func (c *Controller) Disconnect() {
apiPort := tools.FindFreePort()
c.apiPort = apiPort

c.sessionID = uuid.Nil
c.clientIP = ""
}

Expand Down Expand Up @@ -160,24 +158,20 @@ func (c *Controller) GetStats() *common.SystemStatsResponse {
return c.stats
}

func (c *Controller) BaseInfoResponse(includeID bool, extra string) *common.BaseInfoResponse {
func (c *Controller) BaseInfoResponse() *common.BaseInfoResponse {
c.mu.Lock()
defer c.mu.Unlock()

response := &common.BaseInfoResponse{
Started: false,
CoreVersion: "",
NodeVersion: NodeVersion,
Extra: extra,
}

if c.backend != nil {
response.Started = c.backend.Started()
response.CoreVersion = c.backend.GetVersion()
}
if includeID {
response.SessionId = c.sessionID.String()
}

return response
}
7 changes: 2 additions & 5 deletions controller/rest/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

func (s *Service) Base(w http.ResponseWriter, _ *http.Request) {
common.SendProtoResponse(w, s.BaseInfoResponse(false, ""))
common.SendProtoResponse(w, s.BaseInfoResponse())
}

func (s *Service) Start(w http.ResponseWriter, r *http.Request) {
Expand All @@ -36,18 +36,15 @@ func (s *Service) Start(w http.ResponseWriter, r *http.Request) {

s.Connect(ip, keepAlive)

log.Println(ip, " connected, Session ID = ", s.GetSessionID())

if err = s.StartBackend(ctx, backendType); err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}

common.SendProtoResponse(w, s.BaseInfoResponse(true, ""))
common.SendProtoResponse(w, s.BaseInfoResponse())
}

func (s *Service) Stop(w http.ResponseWriter, _ *http.Request) {
log.Println(s.GetIP(), " disconnected, Session ID = ", s.GetSessionID())
s.Disconnect()

common.SendProtoResponse(w, &common.Empty{})
Expand Down
50 changes: 12 additions & 38 deletions controller/rest/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,56 +2,30 @@ package rest

import (
"fmt"
"log"
"net"
"net/http"
"strings"

"github.com/go-chi/chi/v5/middleware"
"github.com/google/uuid"
"log"
"net/http"
)

func (s *Service) checkSessionIDMiddleware(next http.Handler) http.Handler {
func (s *Service) validateApiKey(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// check ip
clientIP := s.GetIP()
clientID := s.GetSessionID()
if clientIP == "" || clientID == uuid.Nil {
http.Error(w, "please connect first", http.StatusTooEarly)
return
}

// check ip
ip, _, err := net.SplitHostPort(r.RemoteAddr)
switch {
case err != nil:
http.Error(w, err.Error(), http.StatusBadRequest)
return
case ip != s.GetIP():
http.Error(w, "IP address is not valid", http.StatusForbidden)
return
}

authHeader := r.Header.Get("Authorization")
if authHeader == "" {
http.Error(w, "please connect first", http.StatusUnauthorized)
apiKeyHeader := r.Header.Get("x-api-key")
if apiKeyHeader == "" {
http.Error(w, "missing x-api-key header", http.StatusUnauthorized)
return
}

parts := strings.Split(authHeader, " ")
if len(parts) != 2 {
http.Error(w, "invalid Authorization header format", http.StatusUnauthorized)
return
}
// check API key
apiKey := s.GetApiKey()

tokenString := parts[1]
sessionID, err := uuid.Parse(tokenString)
key, err := uuid.Parse(apiKeyHeader)
switch {
case err != nil:
http.Error(w, "please send valid uuid", http.StatusUnprocessableEntity)
http.Error(w, "invalid api key format: must be a valid UUID", http.StatusUnprocessableEntity)
return
case sessionID != clientID:
http.Error(w, "session id mismatch.", http.StatusForbidden)
case key != apiKey:
http.Error(w, "api key mismatch", http.StatusForbidden)
return
}

Expand Down
29 changes: 8 additions & 21 deletions controller/rest/rest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ var (
xrayAssetsPath = "/usr/local/share/xray"
sslCertFile = "../../certs/ssl_cert.pem"
sslKeyFile = "../../certs/ssl_key.pem"
sslClientCertFile = "../../certs/ssl_client_cert.pem"
sslClientKeyFile = "../../certs/ssl_client_key.pem"
apiKey = uuid.New()
generatedConfigPath = "../../generated/"
addr = fmt.Sprintf("%s:%d", nodeHost, servicePort)
configPath = "../../backend/xray/config.json"
Expand All @@ -52,8 +51,8 @@ func createHTTPClient(tlsConfig *tls.Config) *http.Client {
}

func TestRESTConnection(t *testing.T) {
config.SetEnv(servicePort, 1000, nodeHost, xrayExecutablePath, xrayAssetsPath,
sslCertFile, sslKeyFile, sslClientCertFile, "rest", generatedConfigPath, true)
config.SetEnv(servicePort, nodeHost, xrayExecutablePath, xrayAssetsPath,
sslCertFile, sslKeyFile, "rest", generatedConfigPath, apiKey, true)

nodeLogger.SetOutputMode(true)

Expand All @@ -65,12 +64,7 @@ func TestRESTConnection(t *testing.T) {
}
}

clientFileExists := tools.FileExists(sslClientCertFile)
if !clientFileExists {
t.Fatal("SSL_CLIENT_CERT_FILE is required.")
}

tlsConfig, err := tools.LoadTLSCredentials(sslCertFile, sslKeyFile, sslClientCertFile, false)
tlsConfig, err := tools.LoadTLSCredentials(sslCertFile, sslKeyFile)
if err != nil {
t.Fatal(err)
}
Expand All @@ -81,16 +75,14 @@ func TestRESTConnection(t *testing.T) {
}
defer s.Disconnect()

creds, err := tools.LoadTLSCredentials(sslClientCertFile, sslClientKeyFile, sslCertFile, true)
certPool, err := tools.LoadClientPool(sslCertFile)
if err != nil {
t.Fatal(err)
}
client := tools.CreateHTTPClient(certPool, nodeHost)

url := fmt.Sprintf("https://%s", addr)

client := createHTTPClient(creds)
sessionId := ""

createAuthenticatedRequest := func(method, endpoint string, data proto.Message, response proto.Message) error {
body, err := proto.Marshal(data)
if err != nil {
Expand All @@ -101,7 +93,7 @@ func TestRESTConnection(t *testing.T) {
if err != nil {
return err
}
req.Header.Set("Authorization", "Bearer "+sessionId)
req.Header.Set("x-api-key", apiKey.String())
if body != nil {
req.Header.Set("Content-Type", "application/x-protobuf")
}
Expand All @@ -124,7 +116,7 @@ func TestRESTConnection(t *testing.T) {
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+sessionId)
req.Header.Set("x-api-key", apiKey.String())

resp, err := client.Do(req)
if err != nil {
Expand Down Expand Up @@ -207,11 +199,6 @@ func TestRESTConnection(t *testing.T) {
t.Fatalf("Failed to start backend: %v", err)
}

sessionId = baseInfoResp.GetSessionId()
if sessionId == "" {
t.Fatal("No session ID received")
}

var stats common.StatResponse
// Try To Get Outbounds Stats
if err = createAuthenticatedRequest("GET", "/stats/outbounds", &common.StatRequest{Reset_: true}, &stats); err != nil {
Expand Down
Loading