diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6e4ec00..e4038dd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,6 +27,17 @@ jobs: - name: Run Tests run: make GOBIN=$HOME/gopath/bin test + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.13' + + - name: Create Python Virtual Environment + run: python -m venv .pyvenv + + - name: Run API Unit Tests + run: make api-test + - name: Run Coverage Tests env: COVERALLS_TOKEN: ${{ secrets.COVERALLS_TOKEN }} @@ -112,12 +123,15 @@ jobs: env: # Set ROLE_ARN based on the branch ROLE_ARN: ${{ github.ref == 'refs/heads/master' && secrets.PROD_LAMBDA_ROLE_ARN || secrets.STAGING_LAMBDA_ROLE_ARN }} + # Set GitHub OAuth credentials based on the branch + GITHUB_OAUTH_CLIENT_ID: ${{ github.ref == 'refs/heads/master' && secrets.PROD_GITHUB_OAUTH_CLIENT_ID || secrets.STAGING_GITHUB_OAUTH_CLIENT_ID }} + GITHUB_OAUTH_CLIENT_SECRET: ${{ github.ref == 'refs/heads/master' && secrets.PROD_GITHUB_OAUTH_CLIENT_SECRET || secrets.STAGING_GITHUB_OAUTH_CLIENT_SECRET }} run: | cd gogen-api if [ "${{ github.ref }}" = "refs/heads/master" ]; then - # Prod deployment, ROLE_ARN is already set via env + # Prod deployment bash deploy_lambdas.sh else - # Staging deployment, ROLE_ARN is already set via env + # Staging deployment bash deploy_lambdas.sh -e staging fi diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index bc1e9f6..477966f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -172,6 +172,8 @@ jobs: - name: Deploy Lambda Functions env: ROLE_ARN: ${{ secrets.PROD_LAMBDA_ROLE_ARN }} + GITHUB_OAUTH_CLIENT_ID: ${{ secrets.PROD_GITHUB_OAUTH_CLIENT_ID }} + GITHUB_OAUTH_CLIENT_SECRET: ${{ secrets.PROD_GITHUB_OAUTH_CLIENT_SECRET }} run: | cd gogen-api bash deploy_lambdas.sh \ No newline at end of file diff --git a/.gitignore b/.gitignore index e47ec5c..492854c 100644 --- a/.gitignore +++ b/.gitignore @@ -21,11 +21,13 @@ 
roveralls* .specstory .pyvenv gogen-api/__pycache__ +__pycache__/ gogen-api/build +gogen-api/env.json ui/node_modules/* ui/dist/* ui/build/* ui/coverage/* ui/public/gogen.wasm ui/.vite -*.idea \ No newline at end of file +*.idea diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..791a3d8 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,116 @@ +# AGENTS.md + +This file gives coding agents the repo-specific context needed to work effectively in `gogen`. + +## Project Overview + +Gogen is a data generator for demo and test data, especially time-series logs and metrics. The repo contains: + +- a Go CLI core +- a Python AWS Lambda API backend in `gogen-api/` +- a React/TypeScript UI in `ui/` + +## Common Commands + +### Go + +```bash +make install # Preferred install path; injects ldflags from Makefile +make build # Cross-compiles linux, darwin, windows, wasm +make test # go test -v ./... +go test -v ./internal +go test -v -run TestName ./internal +``` + +Notes: + +- Use `make install` instead of bare `go install`; version/build metadata and OAuth settings are injected through `-ldflags`. +- Dependencies are vendored. After dependency changes, run `go mod vendor`. 
+
+### Python API
+
+```bash
+cd gogen-api
+./start_dev.sh
+./setup_local_db.sh
+./deploy_lambdas.sh
+```
+
+Repo-standard Python environment (from the repo root):
+
+```bash
+source .pyvenv/bin/activate
+```
+
+Focused API unit tests:
+
+```bash
+make api-test
+```
+
+### UI
+
+```bash
+cd ui
+npm run dev
+npm run build
+npm test
+```
+
+## Architecture
+
+### Go Package Layout
+
+- `main.go`: CLI entry point using `urfave/cli.v1`; maps flags to `GOGEN_*` env vars
+- `internal/`: core config, sample, token, API/share logic
+- `generator/`: generation workers
+- `outputter/`: output workers and destinations
+- `run/`: pipeline orchestration
+- `timer/`: one timer goroutine per sample
+- `rater/`: event-rate control
+- `template/`: output formatting
+- `logger/`: log wrapper
+
+### Data Flow
+
+```text
+YAML/JSON config -> internal.Config singleton
+  -> timer goroutines
+  -> generator worker pool
+  -> outputter worker pool
+  -> output destination
+```
+
+### Config System
+
+- Config is a singleton guarded by `sync.Once`
+- Remote configs default to `https://api.gogen.io` and can be overridden by `GOGEN_APIURL`
+- In Go tests, reset config state with `config.ResetConfig()` before `config.NewConfig()`
+- Tests often use `config.SetupFromString(...)` for inline YAML
+
+### Python API
+
+- Lambda handlers live as separate files in `gogen-api/api/`
+- Backed by DynamoDB + S3
+- Local development uses Docker Compose plus SAM
+- Use `.pyvenv` rather than system Python when running repo Python commands
+
+### UI
+
+- Vite + React 18 + TypeScript + Tailwind
+- Components live in `ui/src/components/`
+- Pages live in `ui/src/pages/`
+- Tests are colocated as `.test.tsx`
+
+## CI/CD
+
+- `.github/workflows/ci.yml` runs Go tests on pushes to `master`/`dev` and on PRs
+- CI also runs `make api-test`
+- Branch builds/deploys happen on `master` and `dev`
+- Release workflow is handled separately in `.github/workflows/release.yml`
+
+## Practical Notes
+
+- Prefer minimal,
targeted edits; this repo spans Go, Python, and frontend code in one tree +- For Python work, prefer adding tests that avoid external AWS dependencies unless the task explicitly needs integration coverage +- For UI tests, keep them aligned with the current design system rather than hardcoding old color classes diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..4b48d2a --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,102 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +Gogen is an open source data generator for generating demo and test data, especially time series log and metric data. It's a Go CLI tool with an embedded Lua scripting engine, a Python AWS Lambda API backend, and a React/TypeScript UI. + +## Common Commands + +### Go (core CLI) + +```bash +make install # Build and install to $GOPATH/bin (default target) +make build # Cross-compile for linux, darwin, windows, wasm +make test # Run all Go tests: go test -v ./... +go test -v ./internal # Run tests for a single package +go test -v -run TestName ./internal # Run a single test +``` + +Version, git summary, build date, and GitHub OAuth credentials are injected via `-ldflags` in the Makefile. Always use `make install` rather than bare `go install`. + +Dependencies are vendored in `vendor/`. After adding deps, run `go mod vendor`. 
+ +### Python API (`gogen-api/`) + +```bash +cd gogen-api +./start_dev.sh # Starts DynamoDB Local + MinIO via docker-compose, then SAM local API on port 4000 +./setup_local_db.sh # Seeds local DynamoDB schema +sam build && sam local start-api --port 4000 --docker-network lambda-local +./deploy_lambdas.sh # Deploy to AWS (requires credentials) +``` + +### UI (`ui/`) + +```bash +cd ui +npm run dev # Vite dev server (copies wasm from build/wasm/ first) +npm run build # Production build +npm test # Jest tests +``` + +## Architecture + +### Go Package Layout + +All packages are at the top level (no `cmd/` or `pkg/` convention): + +- **`main.go`** — CLI entry point using `urfave/cli.v1`. Maps CLI flags to `GOGEN_*` env vars. +- **`internal/`** — Core package. Config singleton, `Sample` struct, `Token` processing, API client, sharing. Imported as `config` throughout (`config "github.com/coccyx/gogen/internal"`). +- **`generator/`** — Reads `GenQueueItem` from channel, dispatches to sample-based or Lua generators. +- **`outputter/`** — Reads `OutQueueItem` from channel, dispatches to output destinations (stdout, file, HTTP, Kafka, network, devnull, buf). +- **`run/`** — Orchestrates the pipeline: timers -> generator worker pool -> outputter worker pool. +- **`timer/`** — One timer goroutine per Sample; handles backfill and realtime intervals. +- **`rater/`** — Controls event rate (config-based, time-of-day/weekday, kbps, Lua script). +- **`template/`** — Output formatting (raw, JSON, CSV, splunkhec, syslog, elasticsearch). +- **`logger/`** — Thin logrus wrapper with file/func/line context hook. + +### Data Flow + +``` +YAML/JSON Config -> internal.Config singleton (sync.Once) + -> [Timer goroutine per Sample] + -> GenQueueItem channel -> [Generator worker pool] + -> OutQueueItem channel -> [Outputter worker pool] + -> output destination +``` + +Concurrency is channel + goroutine worker pools. Worker counts set by `GeneratorWorkers` and `OutputWorkers` config fields. 
+### Key Interfaces
+
+- `internal.Generator` — `Gen(item *GenQueueItem) error`
+- `internal.Outputter` — `Send(events []map[string]string, sample *Sample, outputTemplate string) error`
+- `internal.Rater` — `EventsPerInterval(s *Sample) int`
+
+### Config System
+
+Config is a **singleton** via `sync.Once`. Controlled by environment variables:
+- `GOGEN_HOME`, `GOGEN_FULLCONFIG`, `GOGEN_CONFIG_DIR`, `GOGEN_SAMPLES_DIR`
+- Remote configs fetched from `https://api.gogen.io` (override with `GOGEN_APIURL`)
+
+In tests, call `config.ResetConfig()` before `config.NewConfig()` to get a fresh instance. Tests commonly use `config.SetupFromString(yamlStr)` to inject inline YAML config.
+
+### gogen-api (Python Lambda)
+
+Each Lambda function is a separate `.py` file in `gogen-api/api/`. Backed by DynamoDB + S3. Originally Python 2.7, being updated to Python 3. AWS SAM template at `gogen-api/template.yaml`.
+
+### UI (React/TypeScript)
+
+Vite + React 18 + TypeScript + Tailwind CSS. Components in `src/components/`, pages in `src/pages/`, API clients in `src/api/`, types in `src/types/`. Tests use Jest + React Testing Library, placed adjacent to source as `.test.tsx`.
+
+## CI/CD
+
+GitHub Actions (`.github/workflows/ci.yml`):
+- Push to `master`/`dev` or any PR: runs `make test`, then on `master`/`dev` cross-compiles, builds Docker, pushes artifacts to S3, deploys UI and Lambdas.
+- Tag pushes (`v*.*.*`): full release workflow via `release.yml` — builds, creates GitHub release, pushes Docker images, deploys to production.
+
+## Lua Scripting
+
+Generators (`generator/lua.go`) and raters (`rater/script.go`) support embedded Lua via `gopher-lua` + `gopher-luar`. Lua state persists across calls within a run.
diff --git a/Makefile b/Makefile index 97ffbc4..7aaf2df 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ SUMMARY = $(shell git describe --tags --always --dirty) DATE = $(shell date --rfc-3339=date) -.PHONY: all build deps install test docker splunkapp embed +.PHONY: all build deps install test api-test docker splunkapp embed ifeq ($(OS),Windows_NT) dockercmd := docker run -e TERM -e HOME=/go/src/github.com/coccyx/gogen --rm -it -v $(CURDIR):/go/src/github.com/coccyx/gogen -v $(HOME)/.ssh:/root/.ssh clintsharp/gogen bash @@ -33,6 +33,8 @@ install: test: go test -v ./... +api-test: + ./.pyvenv/bin/python -m unittest gogen-api/test_auth_utils.py gogen-api/test_upsert_auth.py + docker: $(dockercmd) - diff --git a/VERSION b/VERSION index 34a8361..54d1a4f 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.12.1 +0.13.0 diff --git a/generator/generator_test.go b/generator/generator_test.go index 73e2e9a..7568314 100644 --- a/generator/generator_test.go +++ b/generator/generator_test.go @@ -12,23 +12,25 @@ import ( "github.com/stretchr/testify/assert" ) -func TestGenerator(t *testing.T) { - // Setup environment +// setupGenTest resets config, sets env vars, and returns common test fixtures. 
+func setupGenTest(t *testing.T, samplesDir string, seed int64) (func() time.Time, *rand.Rand) { + t.Helper() + config.ResetConfig() os.Setenv("GOGEN_HOME", "..") os.Setenv("GOGEN_ALWAYS_REFRESH", "1") os.Setenv("GOGEN_FULLCONFIG", "") - home := filepath.Join("..", "tests", "tokens") - os.Setenv("GOGEN_SAMPLES_DIR", home) + os.Setenv("GOGEN_SAMPLES_DIR", samplesDir) loc, _ := time.LoadLocation("Local") - source := rand.NewSource(0) - randgen := rand.New(source) - + randgen := rand.New(rand.NewSource(seed)) n := time.Date(2001, 10, 20, 12, 0, 0, 100000, loc) - now := func() time.Time { - return n - } + now := func() time.Time { return n } + return now, randgen +} + +func TestGenerator(t *testing.T) { + home := filepath.Join("..", "tests", "tokens") + now, randgen := setupGenTest(t, home, 0) - // gq := make(chan *config.GenQueueItem) oq := make(chan *config.OutQueueItem) s := tests.FindSampleInFile(home, "token-static") if s == nil { @@ -44,23 +46,171 @@ func TestGenerator(t *testing.T) { assert.Equal(t, "foo", oqi.Events[0]["_raw"]) } -func TestGeneratorCache(t *testing.T) { - // Setup environment +func TestGeneratorMultiPass(t *testing.T) { + home := filepath.Join("..", "tests", "tokens") + now, randgen := setupGenTest(t, home, 0) + + oq := make(chan *config.OutQueueItem) + s := tests.FindSampleInFile(home, "tokens") + if s == nil { + t.Fatalf("Sample tokens not found") + } + // Force MultiPass + s.SinglePass = false + + // Count > lines: tests the iters > 1 path in genMultiPass + gqi := &config.GenQueueItem{Count: len(s.Lines) + 2, Earliest: now(), Latest: now(), Now: now(), S: s, OQ: oq, Rand: randgen, Cache: &config.CacheItem{}} + go func() { + err := genMultiPass(gqi) + assert.NoError(t, err) + }() + + oqi := <-oq + assert.Equal(t, len(s.Lines)+2, len(oqi.Events)) +} + +func TestGeneratorMultiPassRandomize(t *testing.T) { + home := filepath.Join("..", "tests", "tokens") + now, randgen := setupGenTest(t, home, 42) + + oq := make(chan *config.OutQueueItem) + s := 
tests.FindSampleInFile(home, "tokens") + if s == nil { + t.Fatalf("Sample tokens not found") + } + s.SinglePass = false + s.RandomizeEvents = true + + gqi := &config.GenQueueItem{Count: 5, Earliest: now(), Latest: now(), Now: now(), S: s, OQ: oq, Rand: randgen, Cache: &config.CacheItem{}} + go func() { + genMultiPass(gqi) + }() + + oqi := <-oq + assert.Equal(t, 5, len(oqi.Events)) +} + +func TestGeneratorSinglePassCountGtLines(t *testing.T) { + home := filepath.Join("..", "tests", "singlepass") + now, randgen := setupGenTest(t, filepath.Join(home, "test1.yml"), 0) + + c := config.NewConfig() + s := c.FindSampleByName("test1") + if s == nil { + t.Fatalf("Sample test1 not found") + } + assert.True(t, s.SinglePass) + + oq := make(chan *config.OutQueueItem) + // Count > lines: tests the iters > 1 singlepass path + gqi := &config.GenQueueItem{Count: len(s.Lines) + 3, Earliest: now(), Latest: now(), Now: now(), S: s, OQ: oq, Rand: randgen, Cache: &config.CacheItem{}} + go func() { + genSinglePass(gqi) + }() + + oqi := <-oq + assert.Equal(t, len(s.Lines)+3, len(oqi.Events)) +} + +func TestGeneratorSinglePassRandomize(t *testing.T) { + home := filepath.Join("..", "tests", "singlepass") + now, randgen := setupGenTest(t, filepath.Join(home, "test1.yml"), 42) + + c := config.NewConfig() + s := c.FindSampleByName("test1") + if s == nil { + t.Fatalf("Sample test1 not found") + } + assert.True(t, s.SinglePass) + s.RandomizeEvents = true + + oq := make(chan *config.OutQueueItem) + gqi := &config.GenQueueItem{Count: 5, Earliest: now(), Latest: now(), Now: now(), S: s, OQ: oq, Rand: randgen, Cache: &config.CacheItem{}} + go func() { + genSinglePass(gqi) + }() + + oqi := <-oq + assert.Equal(t, 5, len(oqi.Events)) +} + +func TestGeneratorStartWorker(t *testing.T) { + home := filepath.Join("..", "tests", "tokens") + now, randgen := setupGenTest(t, home, 0) + + oq := make(chan *config.OutQueueItem) + s := tests.FindSampleInFile(home, "token-static") + if s == nil { + t.Fatalf("Sample 
token-static not found") + } + + gq := make(chan *config.GenQueueItem) + gqs := make(chan int) + go Start(gq, gqs) + + // Send multiple items to test the "generator already set" path + for i := 0; i < 3; i++ { + gqi := &config.GenQueueItem{Count: 1, Earliest: now(), Latest: now(), Now: now(), S: s, OQ: oq, Rand: randgen, Cache: &config.CacheItem{}} + gq <- gqi + oqi := <-oq + assert.Equal(t, "foo", oqi.Events[0]["_raw"]) + } + + close(gq) + select { + case <-gqs: + case <-time.After(5 * time.Second): + t.Fatal("Generator worker did not finish in time") + } +} + +func TestGeneratorCountMinusOne(t *testing.T) { + home := filepath.Join("..", "tests", "tokens") + now, randgen := setupGenTest(t, home, 0) + + oq := make(chan *config.OutQueueItem) + s := tests.FindSampleInFile(home, "tokens") + if s == nil { + t.Fatalf("Sample tokens not found") + } + // Count=-1 means "use all lines" + gqi := &config.GenQueueItem{Count: -1, Earliest: now(), Latest: now(), Now: now(), S: s, OQ: oq, Rand: randgen, Cache: &config.CacheItem{}} + go func() { + sg := sample{} + sg.Gen(gqi) + }() + + oqi := <-oq + assert.Equal(t, len(s.Lines), len(oqi.Events)) +} + +func TestPrimeRaterSetsRater(t *testing.T) { os.Setenv("GOGEN_HOME", "..") os.Setenv("GOGEN_ALWAYS_REFRESH", "1") - os.Setenv("GOGEN_FULLCONFIG", "") - home := filepath.Join("..", "tests", "tokens") - os.Setenv("GOGEN_SAMPLES_DIR", home) - loc, _ := time.LoadLocation("Local") - source := rand.NewSource(0) - randgen := rand.New(source) - n := time.Date(2001, 10, 20, 12, 0, 0, 100000, loc) - now := func() time.Time { - return n + s := &config.Sample{ + Name: "primerater_test", + Tokens: []config.Token{ + { + Name: "ratedtoken", + Type: "rated", + RaterString: "default", + }, + { + Name: "normaltoken", + Type: "choice", + }, + }, } - // gq := make(chan *config.GenQueueItem) + PrimeRater(s) + assert.NotNil(t, s.Tokens[0].Rater, "rated token should have rater set") +} + +func TestGeneratorCache(t *testing.T) { + home := 
filepath.Join("..", "tests", "tokens") + now, randgen := setupGenTest(t, home, 0) + oq := make(chan *config.OutQueueItem) s := tests.FindSampleInFile(home, "token-static") if s == nil { diff --git a/generator/lua_test.go b/generator/lua_test.go index dabb53b..cc3534c 100644 --- a/generator/lua_test.go +++ b/generator/lua_test.go @@ -288,6 +288,98 @@ func TestSetTime(t *testing.T) { testLuaGen(t, s, gen, "2001-10-20 11:59:59.000100") } +func TestLuaRound(t *testing.T) { + config.ResetConfig() + + os.Setenv("GOGEN_HOME", "..") + os.Setenv("GOGEN_ALWAYS_REFRESH", "") + home := ".." + os.Setenv("GOGEN_FULLCONFIG", filepath.Join(home, "tests", "generator", "luaapi2.yml")) + + c := config.NewConfig() + s := c.FindSampleByName("roundTest") + gen := new(luagen) + runLuaGen(t, s, gen) + time.Sleep(100 * time.Millisecond) + found := false + var token config.Token + for _, tk := range gen.tokens { + if tk.Name == "rounded" { + found = true + token = tk + } + } + assert.True(t, found, "Couldn't find token 'rounded' in sample roundTest") + assert.Equal(t, "3.14", token.Replacement) +} + +func TestLuaLogInfo(t *testing.T) { + config.ResetConfig() + + os.Setenv("GOGEN_HOME", "..") + os.Setenv("GOGEN_ALWAYS_REFRESH", "") + home := ".." + os.Setenv("GOGEN_FULLCONFIG", filepath.Join(home, "tests", "generator", "luaapi2.yml")) + + c := config.NewConfig() + s := c.FindSampleByName("logInfoTest") + gen := new(luagen) + runLuaGen(t, s, gen) + time.Sleep(100 * time.Millisecond) + found := false + var token config.Token + for _, tk := range gen.tokens { + if tk.Name == "logged" { + found = true + token = tk + } + } + assert.True(t, found, "Couldn't find token 'logged' in sample logInfoTest") + assert.Equal(t, "ok", token.Replacement) +} + +func TestRemoveToken(t *testing.T) { + config.ResetConfig() + + os.Setenv("GOGEN_HOME", "..") + os.Setenv("GOGEN_ALWAYS_REFRESH", "") + home := ".." 
+ os.Setenv("GOGEN_FULLCONFIG", filepath.Join(home, "tests", "generator", "luaapi2.yml")) + + c := config.NewConfig() + s := c.FindSampleByName("removeTokenTest") + gen := new(luagen) + runLuaGen(t, s, gen) + time.Sleep(100 * time.Millisecond) + + foundKeeper := false + foundRemover := false + for _, tk := range gen.tokens { + if tk.Name == "keeper" { + foundKeeper = true + } + if tk.Name == "remover" { + foundRemover = true + } + } + assert.True(t, foundKeeper, "Token 'keeper' should still be present") + assert.False(t, foundRemover, "Token 'remover' should have been removed") +} + +func TestSendEvent(t *testing.T) { + config.ResetConfig() + + os.Setenv("GOGEN_HOME", "..") + os.Setenv("GOGEN_ALWAYS_REFRESH", "") + home := ".." + os.Setenv("GOGEN_FULLCONFIG", filepath.Join(home, "tests", "generator", "luaapi2.yml")) + + c := config.NewConfig() + s := c.FindSampleByName("sendEventTest") + gen := new(luagen) + testLuaGen(t, s, gen, "sent via sendEvent") +} + func testLuaGen(t *testing.T, s *config.Sample, gen *luagen, expected string) { oq, err := runLuaGen(t, s, gen) timeout := make(chan bool, 1) diff --git a/gogen-api/README.md b/gogen-api/README.md index b253108..fe9e0f1 100644 --- a/gogen-api/README.md +++ b/gogen-api/README.md @@ -149,6 +149,24 @@ This script will: 4. List objects in the bucket 5. Download and verify the test file +### Testing API Unit Logic + +Focused unit tests are available for request authentication helpers and config ownership enforcement. 
+
+From the repo root:
+
+```bash
+make api-test
+```
+
+Or directly with the project virtual environment (relative to the repo root, matching the Makefile's `api-test` target):
+
+```bash
+./.pyvenv/bin/python -m unittest \
+  gogen-api/test_auth_utils.py \
+  gogen-api/test_upsert_auth.py
+```
+
 ## Accessing Services
 
 ### API Endpoints
@@ -307,4 +325,4 @@ docker-compose logs createbuckets
 - Document code with clear comments
 - Update SUMMARY.md after completing significant features
 - Each AWS Lambda function should be a separate .py file in the `api` directory
-Remember that the codebase is being updated from Python 2.7 to Python 3.13
\ No newline at end of file
+Remember that the codebase is being updated from Python 2.7 to Python 3.13
diff --git a/gogen-api/api/auth.py b/gogen-api/api/auth.py
new file mode 100644
index 0000000..7466eaa
--- /dev/null
+++ b/gogen-api/api/auth.py
@@ -0,0 +1,85 @@
+import json
+from cors_utils import cors_response
+from github_utils import exchange_code_for_token, get_github_user
+from logger import setup_logger
+
+logger = setup_logger(__name__)
+logger.info('Loading function')
+
+
+def respond(err, res=None):
+    if err:
+        return cors_response(400, {'error': str(err)})
+    return cors_response(200, res)
+
+
+def lambda_handler(event, context):
+    """
+    Handle GitHub OAuth code exchange.
+
+    Receives: { "code": "...", "state": "..." }
+    Returns: { "access_token": "...", "user": { "login": "...", "avatar_url": "...", "id": ...
} } + """ + # Handle OPTIONS requests for CORS + if event.get('httpMethod') == 'OPTIONS': + return cors_response(200, {'message': 'OK'}) + + try: + logger.debug(f"Received event: {json.dumps(event, indent=2)}") + + # Validate request body + if 'body' not in event: + logger.error("No request body provided") + return respond("Request body is required") + + try: + body = json.loads(event['body']) + except json.JSONDecodeError as e: + logger.error(f"Invalid JSON in request body: {str(e)}") + return respond("Invalid JSON in request body") + + # Validate required fields + if 'code' not in body: + logger.error("Missing 'code' in request body") + return respond("Missing 'code' in request body") + + code = body['code'] + state = body.get('state') # State is optional but recommended for CSRF protection + + logger.info(f"Processing OAuth code exchange (state: {state})") + + # Exchange code for access token + token_data, error = exchange_code_for_token(code) + if error: + logger.error(f"Failed to exchange code for token: {error}") + return respond(error) + + access_token = token_data['access_token'] + token_type = token_data.get('token_type', 'bearer') + + # Get user information + auth_header = f"{token_type} {access_token}" + user_info, error = get_github_user(auth_header) + if error: + logger.error(f"Failed to get user info: {error}") + return respond(error) + + # Return token and user info + response = { + 'access_token': access_token, + 'token_type': token_type, + 'user': { + 'login': user_info.get('login'), + 'avatar_url': user_info.get('avatar_url'), + 'id': user_info.get('id'), + 'name': user_info.get('name'), + 'email': user_info.get('email') + } + } + + logger.info(f"OAuth exchange successful for user: {user_info.get('login')}") + return respond(None, response) + + except Exception as e: + logger.error(f"Error in lambda_handler: {str(e)}", exc_info=True) + return respond(str(e)) diff --git a/gogen-api/api/auth_utils.py b/gogen-api/api/auth_utils.py new file mode 
100644 index 0000000..8525fe1 --- /dev/null +++ b/gogen-api/api/auth_utils.py @@ -0,0 +1,42 @@ +from cors_utils import cors_response +from github_utils import get_github_user +from logger import setup_logger + +logger = setup_logger(__name__) + + +def get_header(event, name): + """Return an HTTP header value using case-insensitive lookup.""" + headers = event.get('headers') or {} + target = name.lower() + + for key, value in headers.items(): + if key.lower() == target: + return value + + return None + + +def get_authenticated_username(event): + """ + Authenticate the GitHub token from the request and return the username. + + Returns: + tuple[str | None, dict | None]: (username, error_response) + """ + auth_header = get_header(event, 'Authorization') + if not auth_header: + logger.error("Authorization header not present") + return None, cors_response(401, {'error': 'Authorization header not present'}) + + user_info, error = get_github_user(auth_header) + if error: + logger.error(f"Failed to authenticate user: {error}") + return None, cors_response(401, {'error': error}) + + username = user_info.get('login') + if not username: + logger.error("Could not get username from GitHub") + return None, cors_response(401, {'error': 'Could not get username from GitHub'}) + + return username, None diff --git a/gogen-api/api/cors_utils.py b/gogen-api/api/cors_utils.py index b6b26a3..121ed14 100644 --- a/gogen-api/api/cors_utils.py +++ b/gogen-api/api/cors_utils.py @@ -16,7 +16,7 @@ def get_cors_headers() -> Dict[str, str]: return { 'Access-Control-Allow-Origin': origin, - 'Access-Control-Allow-Methods': 'GET,POST,OPTIONS', + 'Access-Control-Allow-Methods': 'GET,POST,DELETE,OPTIONS', 'Access-Control-Allow-Headers': 'Content-Type,Authorization,X-Requested-With', 'Access-Control-Allow-Credentials': 'true', 'Content-Type': 'application/json' diff --git a/gogen-api/api/delete.py b/gogen-api/api/delete.py new file mode 100644 index 0000000..0f75f2f --- /dev/null +++ 
b/gogen-api/api/delete.py @@ -0,0 +1,95 @@ +import json +from db_utils import get_dynamodb_client, get_table_name +from s3_utils import delete_config +from cors_utils import cors_response +from auth_utils import get_authenticated_username +from logger import setup_logger + +logger = setup_logger(__name__) +logger.info('Loading function') + + +def respond(err, res=None, status_code=400): + if err: + return cors_response(status_code, {'error': str(err)}) + return cors_response(200, res) + + +def lambda_handler(event, context): + """ + Handle configuration deletion. + + Validates GitHub token, verifies ownership, then deletes from S3 and DynamoDB. + """ + # Handle OPTIONS requests for CORS + if event.get('httpMethod') == 'OPTIONS': + return cors_response(200, {'message': 'OK'}) + + try: + logger.debug(f"Received event: {json.dumps(event, indent=2)}") + + username, auth_error = get_authenticated_username(event) + if auth_error: + return auth_error + + # Extract config name from path + path_params = event.get('pathParameters', {}) + proxy_path = path_params.get('proxy', '') + + if not proxy_path: + logger.error("No configuration path provided") + return respond("No configuration path provided") + + # Parse owner and config name from path (format: owner/configname) + path_parts = proxy_path.split('/') + if len(path_parts) < 2: + logger.error(f"Invalid path format: {proxy_path}") + return respond("Invalid path format. 
Expected: owner/configname") + + owner = path_parts[0] + config_name = '/'.join(path_parts[1:]) + full_config_name = f"{owner}/{config_name}" + + logger.info(f"User {username} attempting to delete config: {full_config_name}") + + # Verify ownership + if owner != username: + logger.error(f"User {username} attempted to delete config owned by {owner}") + return respond("You can only delete your own configurations", status_code=403) + + # Delete from S3 + s3_path = f"{full_config_name}.yml" + logger.info(f"Deleting from S3: {s3_path}") + s3_deleted = delete_config(s3_path) + if not s3_deleted: + logger.warning(f"Failed to delete S3 object: {s3_path} (may not exist)") + + # Delete from DynamoDB + table = get_dynamodb_client().Table(get_table_name()) + logger.info(f"Deleting from DynamoDB: {full_config_name}") + + try: + response = table.delete_item( + Key={ + 'gogen': full_config_name + }, + ReturnValues='ALL_OLD' + ) + + if 'Attributes' not in response: + logger.warning(f"Configuration not found in DynamoDB: {full_config_name}") + return respond("Configuration not found", status_code=404) + + logger.info(f"Successfully deleted configuration: {full_config_name}") + return respond(None, { + 'message': f"Successfully deleted {full_config_name}", + 'deleted': response.get('Attributes', {}) + }) + + except Exception as e: + logger.error(f"Error deleting from DynamoDB: {str(e)}") + return respond(f"Error deleting configuration: {str(e)}") + + except Exception as e: + logger.error(f"Error in lambda_handler: {str(e)}", exc_info=True) + return respond(str(e)) diff --git a/gogen-api/api/github_utils.py b/gogen-api/api/github_utils.py new file mode 100644 index 0000000..f8af266 --- /dev/null +++ b/gogen-api/api/github_utils.py @@ -0,0 +1,132 @@ +import os +import json +import http.client +from logger import setup_logger + +logger = setup_logger(__name__) + + +def validate_github_token(token): + """ + Validate the GitHub token by making a request to GitHub's API. 
+ + Args: + token: GitHub authorization header value (e.g., "token xxx" or "Bearer xxx") + + Returns: + tuple: (is_valid: bool, error_message: str or None) + """ + headers = { + 'Authorization': token, + 'User-Agent': 'gogen lambda', + 'Content-Length': '0' + } + + logger.debug("Attempting to validate GitHub token") + conn = http.client.HTTPSConnection('api.github.com') + conn.request("GET", "/user", None, headers) + response = conn.getresponse() + + if response.status != 200: + data = response.read().decode('utf-8') + logger.error(f"GitHub token validation failed. Status: {response.status}, Reason: {response.reason}") + logger.debug(f"GitHub API response: {data}") + return False, f"Unable to authenticate user to GitHub, status: {response.status}, msg: {response.reason}" + + logger.info("GitHub token validation successful") + return True, None + + +def get_github_user(token): + """ + Get GitHub user information using the provided token. + + Args: + token: GitHub authorization header value (e.g., "token xxx" or "Bearer xxx") + + Returns: + tuple: (user_info: dict or None, error_message: str or None) + """ + headers = { + 'Authorization': token, + 'User-Agent': 'gogen lambda', + 'Accept': 'application/json' + } + + logger.debug("Fetching GitHub user info") + conn = http.client.HTTPSConnection('api.github.com') + conn.request("GET", "/user", None, headers) + response = conn.getresponse() + data = response.read().decode('utf-8') + + if response.status != 200: + logger.error(f"Failed to get GitHub user. 
Status: {response.status}, Reason: {response.reason}") + logger.debug(f"GitHub API response: {data}") + return None, f"Failed to get GitHub user info, status: {response.status}, msg: {response.reason}" + + try: + user_info = json.loads(data) + logger.info(f"Successfully fetched GitHub user: {user_info.get('login')}") + return user_info, None + except json.JSONDecodeError as e: + logger.error(f"Failed to parse GitHub user response: {str(e)}") + return None, "Failed to parse GitHub user response" + + +def exchange_code_for_token(code): + """ + Exchange an OAuth authorization code for an access token. + + Args: + code: OAuth authorization code from GitHub + + Returns: + tuple: (token_data: dict or None, error_message: str or None) + """ + client_id = os.environ.get('GITHUB_OAUTH_CLIENT_ID') + client_secret = os.environ.get('GITHUB_OAUTH_CLIENT_SECRET') + + if not client_id or not client_secret: + logger.error("GitHub OAuth credentials not configured") + return None, "GitHub OAuth credentials not configured" + + headers = { + 'Accept': 'application/json', + 'Content-Type': 'application/json', + 'User-Agent': 'gogen lambda' + } + + body = json.dumps({ + 'client_id': client_id, + 'client_secret': client_secret, + 'code': code + }) + + logger.debug("Exchanging OAuth code for access token") + conn = http.client.HTTPSConnection('github.com') + conn.request("POST", "/login/oauth/access_token", body, headers) + response = conn.getresponse() + data = response.read().decode('utf-8') + + if response.status != 200: + logger.error(f"OAuth token exchange failed. 
Status: {response.status}, Reason: {response.reason}") + logger.debug(f"GitHub OAuth response: {data}") + return None, f"OAuth token exchange failed, status: {response.status}" + + try: + token_data = json.loads(data) + + if 'error' in token_data: + error_desc = token_data.get('error_description', token_data.get('error')) + logger.error(f"OAuth error: {error_desc}") + return None, f"OAuth error: {error_desc}" + + if 'access_token' not in token_data: + logger.error("No access_token in OAuth response") + return None, "No access_token in OAuth response" + + logger.info("Successfully exchanged OAuth code for access token") + return token_data, None + except json.JSONDecodeError as e: + logger.error(f"Failed to parse OAuth response: {str(e)}") + return None, "Failed to parse OAuth response" diff --git a/gogen-api/api/my_configs.py b/gogen-api/api/my_configs.py new file mode 100644 index 0000000..945a476 --- /dev/null +++ b/gogen-api/api/my_configs.py @@ -0,0 +1,72 @@ +import json +from boto3.dynamodb.conditions import Attr +from db_utils import get_dynamodb_client, get_table_name +from cors_utils import cors_response +from auth_utils import get_authenticated_username +from logger import setup_logger + +logger = setup_logger(__name__) +logger.info('Loading function') + + +def respond(err, res=None, status_code=400): + if err: + return cors_response(status_code, {'error': str(err)}) + return cors_response(200, res) + + +def lambda_handler(event, context): + """ + List configurations owned by the authenticated user. + + Validates GitHub token, gets username, then scans DynamoDB for matching configs. 
+ """ + # Handle OPTIONS requests for CORS + if event.get('httpMethod') == 'OPTIONS': + return cors_response(200, {'message': 'OK'}) + + try: + logger.debug(f"Received event: {json.dumps(event, indent=2)}") + + username, auth_error = get_authenticated_username(event) + if auth_error: + return auth_error + + logger.info(f"Fetching configurations for user: {username}") + + # Scan DynamoDB for configurations owned by this user + table = get_dynamodb_client().Table(get_table_name()) + + # Filter for configs where owner matches the username + filter_expression = Attr('owner').eq(username) + + items = [] + scan_kwargs = { + 'FilterExpression': filter_expression + } + + # Handle pagination + done = False + start_key = None + while not done: + if start_key: + scan_kwargs['ExclusiveStartKey'] = start_key + + response = table.scan(**scan_kwargs) + items.extend(response.get('Items', [])) + + start_key = response.get('LastEvaluatedKey') + done = start_key is None + + logger.info(f"Found {len(items)} configurations for user {username}") + + # Return the list of configurations + return respond(None, { + 'Items': items, + 'Count': len(items), + 'owner': username + }) + + except Exception as e: + logger.error(f"Error in lambda_handler: {str(e)}", exc_info=True) + return respond(str(e)) diff --git a/gogen-api/api/upsert.py b/gogen-api/api/upsert.py index 09b58de..92a0519 100644 --- a/gogen-api/api/upsert.py +++ b/gogen-api/api/upsert.py @@ -1,9 +1,8 @@ import json -import http.client -from boto3.dynamodb.conditions import Key, Attr from db_utils import get_dynamodb_client, get_table_name from s3_utils import upload_config from cors_utils import cors_response +from auth_utils import get_authenticated_username from logger import setup_logger logger = setup_logger(__name__) @@ -16,31 +15,6 @@ def respond(err, res=None): return cors_response(200, res) -def validate_github_token(token): - """ - Validate the GitHub token by making a request to GitHub's API - """ - headers = { - 
'Authorization': token, - 'User-Agent': 'gogen lambda', - 'Content-Length': '0' - } - - logger.debug("Attempting to validate GitHub token") - conn = http.client.HTTPSConnection('api.github.com') - conn.request("GET", "/user", None, headers) - response = conn.getresponse() - - if response.status != 200: - data = response.read().decode('utf-8') - logger.error(f"GitHub token validation failed. Status: {response.status}, Reason: {response.reason}") - logger.debug(f"GitHub API response: {data}") - return False, f"Unable to authenticate user to GitHub, status: {response.status}, msg: {response.reason}" - - logger.info("GitHub token validation successful") - return True, None - - def lambda_handler(event, context): # Handle OPTIONS requests for CORS if event.get('httpMethod') == 'OPTIONS': @@ -60,15 +34,9 @@ def lambda_handler(event, context): logger.error(f"Invalid JSON in request body: {str(e)}") return respond("Invalid JSON in request body") - # Validate GitHub authorization - if 'headers' not in event or 'Authorization' not in event['headers']: - logger.error("Authorization header not present") - return respond("Authorization header not present") - - # Validate GitHub token - is_valid, error_msg = validate_github_token(event['headers']['Authorization']) - if not is_valid: - return respond(error_msg) + username, auth_error = get_authenticated_username(event) + if auth_error: + return auth_error # Validate and clean request body validated_body = {} @@ -84,9 +52,9 @@ def lambda_handler(event, context): if 'config' in validated_body: config_content = validated_body['config'] - # Create S3 path in the format username/sample.yml - if 'owner' in validated_body and 'name' in validated_body: - s3_path = f"{validated_body['owner']}/{validated_body['name']}.yml" + if 'name' in validated_body: + validated_body['owner'] = username + s3_path = f"{username}/{validated_body['name']}.yml" # Upload config to S3 logger.info(f"Uploading config to S3 at path: {s3_path}") @@ -102,12 +70,15 
@@ def lambda_handler(event, context): # Add S3 path to DynamoDB item validated_body['s3Path'] = s3_path - + + # Set the primary key (gogen = owner/name) + validated_body['gogen'] = f"{username}/{validated_body['name']}" + # Remove gistID if present (for migration) validated_body.pop('gistID', None) else: - logger.error("Owner or name missing in request body") - return respond("Owner and name are required fields") + logger.error("Name missing in request body") + return respond("Name is a required field") else: logger.warning("No config found in request body") @@ -126,4 +97,4 @@ def lambda_handler(event, context): except Exception as e: logger.error(f"Error in lambda_handler: {str(e)}", exc_info=True) - return respond(e) \ No newline at end of file + return respond(e) diff --git a/gogen-api/deploy_lambdas.sh b/gogen-api/deploy_lambdas.sh index b1db728..1adfe08 100755 --- a/gogen-api/deploy_lambdas.sh +++ b/gogen-api/deploy_lambdas.sh @@ -55,6 +55,17 @@ if [ -z "$ROLE_ARN" ]; then fi echo "Using role ARN from environment: $ROLE_ARN" +# Expect GitHub OAuth credentials to be set as environment variables +if [ -z "$GITHUB_OAUTH_CLIENT_ID" ]; then + echo "Error: GITHUB_OAUTH_CLIENT_ID environment variable is not set." >&2 + exit 1 +fi +if [ -z "$GITHUB_OAUTH_CLIENT_SECRET" ]; then + echo "Error: GITHUB_OAUTH_CLIENT_SECRET environment variable is not set." 
>&2 + exit 1 +fi +echo "GitHub OAuth credentials configured" + # Create build directory if it doesn't exist mkdir -p $BUILD_DIR @@ -184,6 +195,8 @@ echo " LambdaRoleArn=${ROLE_ARN}" echo " CertificateArn=${CERT_ARN}" echo " ProdTableName=gogen" echo " StagingTableName=gogen-staging" +echo " GitHubOAuthClientId=${GITHUB_OAUTH_CLIENT_ID}" +echo " GitHubOAuthClientSecret=" sam deploy \ --stack-name "gogen-api-${ENVIRONMENT}" \ @@ -194,6 +207,8 @@ sam deploy \ ParameterKey=CertificateArn,ParameterValue=${CERT_ARN} \ ParameterKey=ProdTableName,ParameterValue=gogen \ ParameterKey=StagingTableName,ParameterValue=gogen-staging \ + ParameterKey=GitHubOAuthClientId,ParameterValue=${GITHUB_OAUTH_CLIENT_ID} \ + ParameterKey=GitHubOAuthClientSecret,ParameterValue=${GITHUB_OAUTH_CLIENT_SECRET} \ --capabilities CAPABILITY_IAM CAPABILITY_NAMED_IAM \ --no-confirm-changeset \ --no-fail-on-empty-changeset diff --git a/gogen-api/env.json.example b/gogen-api/env.json.example new file mode 100644 index 0000000..44cf46a --- /dev/null +++ b/gogen-api/env.json.example @@ -0,0 +1,6 @@ +{ + "Parameters": { + "GITHUB_OAUTH_CLIENT_ID": "your_github_oauth_client_id_here", + "GITHUB_OAUTH_CLIENT_SECRET": "your_github_oauth_client_secret_here" + } +} diff --git a/gogen-api/iam_policy.json b/gogen-api/iam_policy.json index 40a93b4..2572512 100644 --- a/gogen-api/iam_policy.json +++ b/gogen-api/iam_policy.json @@ -43,7 +43,9 @@ "s3:ListBucket", "s3:GetBucketLocation", "iam:PassRole", - "iam:GetRole" + "iam:GetRole", + "iam:AttachRolePolicy", + "iam:DetachRolePolicy" ], "Resource": [ "arn:aws:cloudformation:*:*:stack/gogen-api-prod/*", diff --git a/gogen-api/start_dev.sh b/gogen-api/start_dev.sh index f328e0d..036ca93 100755 --- a/gogen-api/start_dev.sh +++ b/gogen-api/start_dev.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/bin/bash # Determine script directory and project root SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" @@ -18,6 +18,30 @@ else echo "Consider creating it with: python3 -m 
venv $VENV_PATH" fi +# Check for env.json (OAuth credentials) +ENV_FILE="$SCRIPT_DIR/env.json" +SAM_ENV_ARGS="" +if [ -f "$ENV_FILE" ]; then + echo "Found env.json - OAuth credentials will be loaded" + SAM_ENV_ARGS="--env-vars $ENV_FILE" +else + echo "" + echo "==========================================" + echo "WARNING: env.json not found!" + echo "GitHub OAuth login will not work locally." + echo "" + echo "To enable OAuth, create env.json from the template:" + echo " cp env.json.example env.json" + echo " # Then edit env.json with your GitHub OAuth credentials" + echo "" + echo "Get OAuth credentials by creating a GitHub OAuth App at:" + echo " https://github.com/settings/developers" + echo " - Homepage URL: http://localhost:3000" + echo " - Callback URL: http://localhost:3000/auth/callback" + echo "==========================================" + echo "" +fi + # Start Docker containers echo "Starting Docker containers..." cd "$SCRIPT_DIR" @@ -36,7 +60,7 @@ run_test_commands() { echo "Running test gogen commands to validate API..." GOGEN_APIURL=http://localhost:4000 gogen -c "$PROJECT_ROOT/examples/tutorial/tutorial1.yml" push tutorial1 GOGEN_APIURL=http://localhost:4000 gogen -c coccyx/tutorial1 config - + echo "Test commands completed." } @@ -49,8 +73,8 @@ sam build run_test_commands & TEST_COMMANDS_PID=$! 
-# Start SAM local in foreground -sam local start-api --host 0.0.0.0 --port 4000 --warm-containers EAGER --docker-network lambda-local +# Start SAM local in foreground (with env vars if available) +sam local start-api --host 0.0.0.0 --port 4000 --warm-containers EAGER --docker-network lambda-local $SAM_ENV_ARGS # Trap Ctrl+C and call cleanup cleanup() { @@ -66,4 +90,4 @@ cleanup() { trap cleanup INT -cleanup \ No newline at end of file +cleanup diff --git a/gogen-api/template.yaml b/gogen-api/template.yaml index c63a9bd..9af8e7b 100644 --- a/gogen-api/template.yaml +++ b/gogen-api/template.yaml @@ -80,10 +80,11 @@ Resources: version: '1.0' x-amazon-apigateway-cors: allowOrigins: - - !If [IsProduction, "'https://gogen.io'", "'https://staging.gogen.io'"] + - !If [IsProduction, "https://gogen.io", "https://staging.gogen.io"] allowMethods: - GET - POST + - DELETE - OPTIONS allowHeaders: - Content-Type @@ -213,6 +214,26 @@ Resources: httpMethod: POST type: aws_proxy uri: !Sub "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${GetFunction.Arn}/invocations" + options: + summary: CORS preflight + responses: + '200': + description: CORS headers + headers: + Access-Control-Allow-Origin: { schema: { type: string } } + Access-Control-Allow-Methods: { schema: { type: string } } + Access-Control-Allow-Headers: { schema: { type: string } } + x-amazon-apigateway-integration: + type: mock + requestTemplates: + application/json: '{"statusCode": 200}' + responses: + default: + statusCode: '200' + responseParameters: + method.response.header.Access-Control-Allow-Origin: !If [IsProduction, "'https://gogen.io'", "'https://staging.gogen.io'"] + method.response.header.Access-Control-Allow-Methods: "'GET,POST,DELETE,OPTIONS'" + method.response.header.Access-Control-Allow-Headers: "'Content-Type,Authorization,X-Requested-With'" /v1/list: get: responses: {} @@ -220,6 +241,26 @@ Resources: httpMethod: POST type: aws_proxy uri: !Sub 
"arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${ListFunction.Arn}/invocations" + options: + summary: CORS preflight + responses: + '200': + description: CORS headers + headers: + Access-Control-Allow-Origin: { schema: { type: string } } + Access-Control-Allow-Methods: { schema: { type: string } } + Access-Control-Allow-Headers: { schema: { type: string } } + x-amazon-apigateway-integration: + type: mock + requestTemplates: + application/json: '{"statusCode": 200}' + responses: + default: + statusCode: '200' + responseParameters: + method.response.header.Access-Control-Allow-Origin: !If [IsProduction, "'https://gogen.io'", "'https://staging.gogen.io'"] + method.response.header.Access-Control-Allow-Methods: "'GET,POST,DELETE,OPTIONS'" + method.response.header.Access-Control-Allow-Headers: "'Content-Type,Authorization,X-Requested-With'" /v1/search: get: responses: {} @@ -227,6 +268,26 @@ Resources: httpMethod: POST type: aws_proxy uri: !Sub "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${SearchFunction.Arn}/invocations" + options: + summary: CORS preflight + responses: + '200': + description: CORS headers + headers: + Access-Control-Allow-Origin: { schema: { type: string } } + Access-Control-Allow-Methods: { schema: { type: string } } + Access-Control-Allow-Headers: { schema: { type: string } } + x-amazon-apigateway-integration: + type: mock + requestTemplates: + application/json: '{"statusCode": 200}' + responses: + default: + statusCode: '200' + responseParameters: + method.response.header.Access-Control-Allow-Origin: !If [IsProduction, "'https://gogen.io'", "'https://staging.gogen.io'"] + method.response.header.Access-Control-Allow-Methods: "'GET,POST,DELETE,OPTIONS'" + method.response.header.Access-Control-Allow-Headers: "'Content-Type,Authorization,X-Requested-With'" /v1/upsert: post: responses: {} @@ -234,6 +295,107 @@ Resources: httpMethod: POST type: aws_proxy uri: !Sub 
"arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${UpsertFunction.Arn}/invocations" + options: + summary: CORS preflight + responses: + '200': + description: CORS headers + headers: + Access-Control-Allow-Origin: { schema: { type: string } } + Access-Control-Allow-Methods: { schema: { type: string } } + Access-Control-Allow-Headers: { schema: { type: string } } + x-amazon-apigateway-integration: + type: mock + requestTemplates: + application/json: '{"statusCode": 200}' + responses: + default: + statusCode: '200' + responseParameters: + method.response.header.Access-Control-Allow-Origin: !If [IsProduction, "'https://gogen.io'", "'https://staging.gogen.io'"] + method.response.header.Access-Control-Allow-Methods: "'GET,POST,DELETE,OPTIONS'" + method.response.header.Access-Control-Allow-Headers: "'Content-Type,Authorization,X-Requested-With'" + /v1/auth/github: + post: + responses: {} + x-amazon-apigateway-integration: + httpMethod: POST + type: aws_proxy + uri: !Sub "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${AuthFunction.Arn}/invocations" + options: + summary: CORS preflight + responses: + '200': + description: CORS headers + headers: + Access-Control-Allow-Origin: { schema: { type: string } } + Access-Control-Allow-Methods: { schema: { type: string } } + Access-Control-Allow-Headers: { schema: { type: string } } + x-amazon-apigateway-integration: + type: mock + requestTemplates: + application/json: '{"statusCode": 200}' + responses: + default: + statusCode: '200' + responseParameters: + method.response.header.Access-Control-Allow-Origin: !If [IsProduction, "'https://gogen.io'", "'https://staging.gogen.io'"] + method.response.header.Access-Control-Allow-Methods: "'GET,POST,DELETE,OPTIONS'" + method.response.header.Access-Control-Allow-Headers: "'Content-Type,Authorization,X-Requested-With'" + /v1/delete/{proxy+}: + delete: + responses: {} + x-amazon-apigateway-integration: + httpMethod: POST + type: aws_proxy + uri: !Sub 
"arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${DeleteFunction.Arn}/invocations" + options: + summary: CORS preflight + responses: + '200': + description: CORS headers + headers: + Access-Control-Allow-Origin: { schema: { type: string } } + Access-Control-Allow-Methods: { schema: { type: string } } + Access-Control-Allow-Headers: { schema: { type: string } } + x-amazon-apigateway-integration: + type: mock + requestTemplates: + application/json: '{"statusCode": 200}' + responses: + default: + statusCode: '200' + responseParameters: + method.response.header.Access-Control-Allow-Origin: !If [IsProduction, "'https://gogen.io'", "'https://staging.gogen.io'"] + method.response.header.Access-Control-Allow-Methods: "'GET,POST,DELETE,OPTIONS'" + method.response.header.Access-Control-Allow-Headers: "'Content-Type,Authorization,X-Requested-With'" + /v1/my-configs: + get: + responses: {} + x-amazon-apigateway-integration: + httpMethod: POST + type: aws_proxy + uri: !Sub "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${MyConfigsFunction.Arn}/invocations" + options: + summary: CORS preflight + responses: + '200': + description: CORS headers + headers: + Access-Control-Allow-Origin: { schema: { type: string } } + Access-Control-Allow-Methods: { schema: { type: string } } + Access-Control-Allow-Headers: { schema: { type: string } } + x-amazon-apigateway-integration: + type: mock + requestTemplates: + application/json: '{"statusCode": 200}' + responses: + default: + statusCode: '200' + responseParameters: + method.response.header.Access-Control-Allow-Origin: !If [IsProduction, "'https://gogen.io'", "'https://staging.gogen.io'"] + method.response.header.Access-Control-Allow-Methods: "'GET,POST,DELETE,OPTIONS'" + method.response.header.Access-Control-Allow-Headers: "'Content-Type,Authorization,X-Requested-With'" Domain: DomainName: !If - IsProduction @@ -311,6 +473,58 @@ Resources: ENVIRONMENT: !Ref Environment DYNAMODB_TABLE_NAME: !If 
[IsProduction, !Ref ProdTableName, !Ref StagingTableName] + AuthFunction: + Type: AWS::Serverless::Function + Metadata: + Dockerfile: Dockerfile + DockerContext: . + DockerTag: python3.13-v1 + Properties: + CodeUri: ./api + Handler: auth.lambda_handler + Runtime: python3.13 + Timeout: 10 + Role: !Ref LambdaRoleArn + Environment: + Variables: + ENVIRONMENT: !Ref Environment + GITHUB_OAUTH_CLIENT_ID: !Ref GitHubOAuthClientId + GITHUB_OAUTH_CLIENT_SECRET: !Ref GitHubOAuthClientSecret + + DeleteFunction: + Type: AWS::Serverless::Function + Metadata: + Dockerfile: Dockerfile + DockerContext: . + DockerTag: python3.13-v1 + Properties: + CodeUri: ./api + Handler: delete.lambda_handler + Runtime: python3.13 + Timeout: 10 + Role: !Ref LambdaRoleArn + Environment: + Variables: + ENVIRONMENT: !Ref Environment + DYNAMODB_TABLE_NAME: !If [IsProduction, !Ref ProdTableName, !Ref StagingTableName] + + MyConfigsFunction: + Type: AWS::Serverless::Function + Metadata: + Dockerfile: Dockerfile + DockerContext: . 
+ DockerTag: python3.13-v1 + Properties: + CodeUri: ./api + Handler: my_configs.lambda_handler + Runtime: python3.13 + Timeout: 10 + Role: !Ref LambdaRoleArn + Environment: + Variables: + ENVIRONMENT: !Ref Environment + DYNAMODB_TABLE_NAME: !If [IsProduction, !Ref ProdTableName, !Ref StagingTableName] + ListFunctionApiGatewayInvokePermission: Type: AWS::Lambda::Permission Properties: @@ -343,6 +557,30 @@ Resources: Principal: apigateway.amazonaws.com SourceArn: !Sub arn:aws:execute-api:${AWS::Region}:${AWS::AccountId}:${GoGenApi}/*/GET/v1/get/* + AuthFunctionApiGatewayInvokePermission: + Type: AWS::Lambda::Permission + Properties: + Action: lambda:InvokeFunction + FunctionName: !Ref AuthFunction + Principal: apigateway.amazonaws.com + SourceArn: !Sub arn:aws:execute-api:${AWS::Region}:${AWS::AccountId}:${GoGenApi}/*/POST/v1/auth/github + + DeleteFunctionApiGatewayInvokePermission: + Type: AWS::Lambda::Permission + Properties: + Action: lambda:InvokeFunction + FunctionName: !Ref DeleteFunction + Principal: apigateway.amazonaws.com + SourceArn: !Sub arn:aws:execute-api:${AWS::Region}:${AWS::AccountId}:${GoGenApi}/*/DELETE/v1/delete/* + + MyConfigsFunctionApiGatewayInvokePermission: + Type: AWS::Lambda::Permission + Properties: + Action: lambda:InvokeFunction + FunctionName: !Ref MyConfigsFunction + Principal: apigateway.amazonaws.com + SourceArn: !Sub arn:aws:execute-api:${AWS::Region}:${AWS::AccountId}:${GoGenApi}/*/GET/v1/my-configs + Conditions: IsStagingEnvironment: !Equals - !Ref Environment @@ -378,6 +616,15 @@ Parameters: Default: gogen-staging Description: Name of the existing DynamoDB table for Staging + GitHubOAuthClientId: + Type: String + Description: GitHub OAuth App Client ID + + GitHubOAuthClientSecret: + Type: String + NoEcho: true + Description: GitHub OAuth App Client Secret + Outputs: ApiURL: Description: API Gateway endpoint URL diff --git a/gogen-api/test_auth_utils.py b/gogen-api/test_auth_utils.py new file mode 100644 index 0000000..2371a12 --- 
/dev/null +++ b/gogen-api/test_auth_utils.py @@ -0,0 +1,43 @@ +import sys +import unittest +from pathlib import Path +from unittest.mock import patch + +sys.path.insert(0, str(Path(__file__).resolve().parent / 'api')) + +import auth_utils # noqa: E402 + + +class AuthUtilsTest(unittest.TestCase): + def test_get_header_is_case_insensitive(self): + event = { + 'headers': { + 'authorization': 'token abc123', + } + } + + self.assertEqual(auth_utils.get_header(event, 'Authorization'), 'token abc123') + + @patch('auth_utils.get_github_user') + def test_get_authenticated_username_returns_username(self, mock_get_github_user): + mock_get_github_user.return_value = ({'login': 'clint'}, None) + event = {'headers': {'Authorization': 'token abc123'}} + + username, error = auth_utils.get_authenticated_username(event) + + self.assertEqual(username, 'clint') + self.assertIsNone(error) + mock_get_github_user.assert_called_once_with('token abc123') + + @patch('auth_utils.get_github_user') + def test_get_authenticated_username_returns_401_without_authorization_header(self, mock_get_github_user): + username, error = auth_utils.get_authenticated_username({'headers': {}}) + + self.assertIsNone(username) + self.assertEqual(error['statusCode'], '401') + self.assertIn('Authorization header not present', error['body']) + mock_get_github_user.assert_not_called() + + +if __name__ == '__main__': + unittest.main() diff --git a/gogen-api/test_upsert_auth.py b/gogen-api/test_upsert_auth.py new file mode 100644 index 0000000..beef1f6 --- /dev/null +++ b/gogen-api/test_upsert_auth.py @@ -0,0 +1,83 @@ +import json +import sys +import unittest +from pathlib import Path +from types import ModuleType +from unittest.mock import Mock, patch + +sys.path.insert(0, str(Path(__file__).resolve().parent / 'api')) + +boto3_module = ModuleType('boto3') +boto3_module.resource = Mock() +sys.modules.setdefault('boto3', boto3_module) + +botocore_module = ModuleType('botocore') +botocore_config_module = 
ModuleType('botocore.config') + + +class Config: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + +botocore_config_module.Config = Config +sys.modules.setdefault('botocore', botocore_module) +sys.modules.setdefault('botocore.config', botocore_config_module) + +import upsert # noqa: E402 + + +class UpsertAuthTest(unittest.TestCase): + @patch('upsert.get_table_name', return_value='gogen') + @patch('upsert.get_dynamodb_client') + @patch('upsert.upload_config', return_value=True) + @patch('upsert.get_authenticated_username', return_value=('actual-user', None)) + def test_upsert_uses_authenticated_username_for_owner( + self, + mock_get_authenticated_username, + mock_upload_config, + mock_get_dynamodb_client, + _mock_get_table_name, + ): + table = Mock() + table.put_item.return_value = {'ResponseMetadata': {'HTTPStatusCode': 200}} + mock_get_dynamodb_client.return_value.Table.return_value = table + + event = { + 'headers': {'authorization': 'token abc123'}, + 'body': json.dumps({ + 'owner': 'spoofed-user', + 'name': 'sample', + 'description': 'demo', + 'config': 'global: {}' + }) + } + + response = upsert.lambda_handler(event, None) + + self.assertEqual(response['statusCode'], '200') + mock_get_authenticated_username.assert_called_once_with(event) + mock_upload_config.assert_called_once_with('actual-user/sample.yml', 'global: {}') + table.put_item.assert_called_once() + + stored_item = table.put_item.call_args.kwargs['Item'] + self.assertEqual(stored_item['owner'], 'actual-user') + self.assertEqual(stored_item['gogen'], 'actual-user/sample') + self.assertEqual(stored_item['s3Path'], 'actual-user/sample.yml') + + @patch('upsert.get_authenticated_username') + def test_upsert_returns_auth_error_response(self, mock_get_authenticated_username): + mock_get_authenticated_username.return_value = ( + None, + {'statusCode': '401', 'body': json.dumps({'error': 'Authorization header not present'})} + ) + + response = 
upsert.lambda_handler({'headers': {}, 'body': '{}'}, None) + + self.assertEqual(response['statusCode'], '401') + self.assertIn('Authorization header not present', response['body']) + + +if __name__ == '__main__': + unittest.main() diff --git a/internal/config.go b/internal/config.go index 04c2da7..a2d4209 100644 --- a/internal/config.go +++ b/internal/config.go @@ -4,25 +4,18 @@ import ( "bufio" "bytes" "encoding/csv" - "encoding/json" "fmt" - "io/ioutil" - "net/http" "os" "path/filepath" "runtime/debug" - "sort" "strconv" "strings" - "sync" "time" log "github.com/coccyx/gogen/logger" "github.com/coccyx/gogen/template" "github.com/coccyx/timeparser" "github.com/kr/pretty" - lua "github.com/yuin/gopher-lua" - yaml "gopkg.in/yaml.v2" ) // Config is a struct representing a Singleton which contains a copy of the running config @@ -93,9 +86,16 @@ type Share interface { } var instance *Config -var once sync.Once var share Share +// setDefault sets *ptr to defaultVal if *ptr is the zero value for its type. +func setDefault[T comparable](ptr *T, defaultVal T) { + var zero T + if *ptr == zero { + *ptr = defaultVal + } +} + func getConfig() *Config { if instance == nil { instance = &Config{initialized: false} @@ -167,7 +167,7 @@ func BuildConfig(cc ConfigConfig) *Config { if len(cc.FullConfig) > 0 { cc.FullConfig = os.ExpandEnv(cc.FullConfig) - if cc.FullConfig[0:4] == "http" { + if strings.HasPrefix(cc.FullConfig, "http") { log.Infof("Fetching config from '%s'", cc.FullConfig) if err := c.parseWebConfig(&c, cc.FullConfig); err != nil { log.Panic(err) @@ -180,7 +180,6 @@ func BuildConfig(cc ConfigConfig) *Config { if err := c.parseFileConfig(&c, cc.FullConfig); err != nil { log.Panic(err) } - // if filepath.Dir(cc.FullConfig) != "." 
&& !strings.Contains(cc.FullConfig, "tests") { if !strings.Contains(cc.FullConfig, "tests") { c.Global.SamplesDir = append(c.Global.SamplesDir, filepath.Dir(cc.FullConfig)) } @@ -195,54 +194,24 @@ func BuildConfig(cc ConfigConfig) *Config { } } } - if c.Global.ROTInterval == 0 { - c.Global.ROTInterval = defaultROTInterval - } + setDefault(&c.Global.ROTInterval, defaultROTInterval) // Don't set defaults if we're exporting if !cc.Export { - // // Setup defaults for global - // - if c.Global.GeneratorWorkers == 0 { - c.Global.GeneratorWorkers = defaultGeneratorWorkers - } - if c.Global.OutputWorkers == 0 { - c.Global.OutputWorkers = defaultOutputWorkers - } - if c.Global.GeneratorQueueLength == 0 { - c.Global.GeneratorQueueLength = defaultGenQueueLength - } - if c.Global.OutputQueueLength == 0 { - c.Global.OutputQueueLength = defaultOutQueueLength - } - if c.Global.Output.Outputter == "" { - c.Global.Output.Outputter = defaultOutputter - } - if c.Global.Output.OutputTemplate == "" { - c.Global.Output.OutputTemplate = defaultOutputTemplate - } + setDefault(&c.Global.GeneratorWorkers, defaultGeneratorWorkers) + setDefault(&c.Global.OutputWorkers, defaultOutputWorkers) + setDefault(&c.Global.GeneratorQueueLength, defaultGenQueueLength) + setDefault(&c.Global.OutputQueueLength, defaultOutQueueLength) + setDefault(&c.Global.Output.Outputter, defaultOutputter) + setDefault(&c.Global.Output.OutputTemplate, defaultOutputTemplate) - // // Setup defaults for outputs - // - if c.Global.Output.FileName == "" { - c.Global.Output.FileName = defaultFileName - } - if c.Global.Output.BackupFiles == 0 { - c.Global.Output.BackupFiles = defaultBackupFiles - } - if c.Global.Output.MaxBytes == 0 { - c.Global.Output.MaxBytes = defaultMaxBytes - } - if c.Global.Output.BufferBytes == 0 { - c.Global.Output.BufferBytes = defaultBufferBytes - } - if c.Global.Output.Timeout == time.Duration(0) { - c.Global.Output.Timeout = defaultTimeout - } - if c.Global.Output.Topic == "" { - 
c.Global.Output.Topic = defaultTopic - } + setDefault(&c.Global.Output.FileName, defaultFileName) + setDefault(&c.Global.Output.BackupFiles, defaultBackupFiles) + setDefault(&c.Global.Output.MaxBytes, defaultMaxBytes) + setDefault(&c.Global.Output.BufferBytes, defaultBufferBytes) + setDefault(&c.Global.Output.Timeout, defaultTimeout) + setDefault(&c.Global.Output.Topic, defaultTopic) if len(c.Global.Output.Headers) == 0 { c.Global.Output.Headers = map[string]string{ "Content-Type": "application/json", @@ -261,60 +230,25 @@ func BuildConfig(cc ConfigConfig) *Config { c.Templates = append(c.Templates, templates...) for _, t := range c.Templates { if len(t.Header) > 0 { - _ = template.New(t.Name+"_header", t.Header) + if err := template.New(t.Name+"_header", t.Header); err != nil { + log.Errorf("Error creating header template for '%s': %v", t.Name, err) + } + } + if err := template.New(t.Name+"_row", t.Row); err != nil { + log.Errorf("Error creating row template for '%s': %v", t.Name, err) } - _ = template.New(t.Name+"_row", t.Row) if len(t.Footer) > 0 { - _ = template.New(t.Name+"_footer", t.Footer) + if err := template.New(t.Name+"_footer", t.Footer); err != nil { + log.Errorf("Error creating footer template for '%s': %v", t.Name, err) + } } } } if len(cc.FullConfig) == 0 { - // Read all templates in $GOGEN_HOME/config/templates - fullPath := filepath.Join(cc.ConfigDir, "templates") - acceptableExtensions := map[string]bool{".yml": true, ".yaml": true, ".json": true} - c.walkPath(fullPath, acceptableExtensions, func(innerPath string) error { - t := new(Template) - - if err := c.parseFileConfig(&t, innerPath); err != nil { - log.Errorf("Error parsing config %s: %s", innerPath, err) - return err - } - - c.Templates = append(c.Templates, t) - return nil - }) - - // Read all raters in $GOGEN_HOME/config/raters - fullPath = filepath.Join(cc.ConfigDir, "raters") - acceptableExtensions = map[string]bool{".yml": true, ".yaml": true, ".json": true} - c.walkPath(fullPath, 
acceptableExtensions, func(innerPath string) error { - var r RaterConfig - - if err := c.parseFileConfig(&r, innerPath); err != nil { - log.Errorf("Error parsing config %s: %s", innerPath, err) - return err - } - - c.Raters = append(c.Raters, &r) - return nil - }) - - // Read all generators in $GOGEN_HOME/config/generators - fullPath = filepath.Join(cc.ConfigDir, "generators") - acceptableExtensions = map[string]bool{".yml": true, ".yaml": true, ".json": true} - c.walkPath(fullPath, acceptableExtensions, func(innerPath string) error { - var g GeneratorConfig - - if err := c.parseFileConfig(&g, innerPath); err != nil { - log.Errorf("Error parsing config %s: %s", innerPath, err) - return err - } - - c.Generators = append(c.Generators, &g) - return nil - }) + loadConfigDir(c, cc.ConfigDir, "templates", &c.Templates) + loadConfigDir(c, cc.ConfigDir, "raters", &c.Raters) + loadConfigDir(c, cc.ConfigDir, "generators", &c.Generators) c.readSamplesDir(cc.SamplesDir) } @@ -406,8 +340,8 @@ func BuildConfig(cc ConfigConfig) *Config { for _, m := range c.Mix { cc := ConfigConfig{FullConfig: m.Sample, Export: false} var nc *Config - acceptableExtensions := map[string]bool{".yml": true, ".yaml": true, ".json": true, ".sample": true, ".csv": true} - if _, ok := acceptableExtensions[filepath.Ext(m.Sample)]; ok { + mixExtensions := map[string]bool{".yml": true, ".yaml": true, ".json": true, ".sample": true, ".csv": true} + if _, ok := mixExtensions[filepath.Ext(m.Sample)]; ok { nc = BuildConfig(cc) c.mergeMixConfig(nc, m) } else { @@ -528,8 +462,7 @@ func (c *Config) readSamplesDir(samplesDir string) { }) // Read all YAML & JSON samples in $GOGEN_HOME/config/samples directory - acceptableExtensions = map[string]bool{".yml": true, ".yaml": true, ".json": true} - c.walkPath(samplesDir, acceptableExtensions, func(innerPath string) error { + c.walkPath(samplesDir, configExtensions, func(innerPath string) error { if c.cc.FullConfig != innerPath { log.Debugf("Loading YAML sample '%s'", 
innerPath) s := Sample{} @@ -546,414 +479,6 @@ func (c *Config) readSamplesDir(samplesDir string) { }) } -// validate takes a sample and checks against any rules which may cause the configuration to be invalid. -// This hopefully centralizes logic for valid configs, disabling any samples which are not valid and -// preventing this logic from sprawling all over the code base. -// Also finds any references from tokens to other samples and -// updates the token to point to the sample data -// Also fixes up any additional things which are needed, like weighted choice string -// string map to the randutil Choice struct -func (c *Config) validate(s *Sample) { - if s.realSample { - s.Buf = &c.Buf - if s.Generator == "" { - s.Generator = defaultGenerator - } - if len(s.Name) == 0 { - s.Disabled = true - s.realSample = false - } else if len(s.Lines) == 0 && (s.Generator == "sample" || s.Generator == "replay") { - s.Disabled = true - s.realSample = false - log.Errorf("Disabling sample '%s', no lines in sample", s.Name) - } else { - s.realSample = true - } - - // Put the output into the sample for convenience - s.Output = &c.Global.Output - - // Setup defaults - if s.Earliest == "" { - s.Earliest = defaultEarliest - } - if s.Latest == "" { - s.Latest = defaultLatest - } - if s.RandomizeEvents == false { - s.RandomizeEvents = defaultRandomizeEvents - } - if s.Field == "" { - s.Field = DefaultField - } - if s.RaterString == "" { - s.RaterString = defaultRater - } - - ParseBeginEnd(s) - - // - // Parse earliest and latest as relative times - // - // Cache a time so we can get a delta for parsed begin, end, earliest and latest - n := time.Now() - now := func() time.Time { - return n - } - if p, err := timeparser.TimeParserNow(s.Earliest, now); err != nil { - log.Errorf("Error parsing earliest time '%s' for sample '%s', using Now", s.Earliest, s.Name) - s.EarliestParsed = time.Duration(0) - } else { - s.EarliestParsed = n.Sub(p) * -1 - } - if p, err := 
timeparser.TimeParserNow(s.Latest, now); err != nil { - log.Errorf("Error parsing latest time '%s' for sample '%s', using Now", s.Latest, s.Name) - s.LatestParsed = time.Duration(0) - } else { - s.LatestParsed = n.Sub(p) * -1 - } - - // log.Debugf("Resolving '%s'", s.Name) - for i := 0; i < len(s.Tokens); i++ { - if s.Tokens[i].Type == "rated" && s.Tokens[i].RaterString == "" { - s.Tokens[i].RaterString = "default" - } - if s.Tokens[i].Field == "" { - s.Tokens[i].Field = s.Field - } - // If format is template, then create a default token of $tokenname$ - if s.Tokens[i].Format == "template" && s.Tokens[i].Token == "" { - s.Tokens[i].Token = "$" + s.Tokens[i].Name + "$" - } - s.Tokens[i].Parent = s - s.Tokens[i].luaState = new(lua.LTable) - // log.Debugf("Resolving token '%s' for sample '%s'", s.Tokens[i].Name, s.Name) - for j := 0; j < len(c.Samples); j++ { - if s.Tokens[i].SampleString != "" && s.Tokens[i].SampleString == c.Samples[j].Name { - log.Debugf("Resolving sample '%s' for token '%s'", c.Samples[j].Name, s.Tokens[i].Name) - s.Tokens[i].Sample = c.Samples[j] - // See if a field exists other than _raw, if so, FieldChoice - otherfield := false - if len(c.Samples[j].Lines) > 0 { - for k := range c.Samples[j].Lines[0] { - if k != "_raw" { - otherfield = true - break - } - } - } - if otherfield { - // If we're a structured sample and we contain the field "_weight", then we create a weighted choice struct - // Otherwise we're a fieldChoice - _, ok := c.Samples[j].Lines[0]["_weight"] - _, ok2 := c.Samples[j].Lines[0][s.Tokens[i].SrcField] - if ok && ok2 { - for _, line := range c.Samples[j].Lines { - weight, err := strconv.Atoi(line["_weight"]) - if err != nil { - weight = 0 - } - s.Tokens[i].WeightedChoice = append(s.Tokens[i].WeightedChoice, WeightedChoice{Weight: weight, Choice: line[s.Tokens[i].SrcField]}) - } - } else { - s.Tokens[i].FieldChoice = c.Samples[j].Lines - } - } else { - // s.Tokens[i].WeightedChoice = c.Samples[j].Lines - temp := make([]string, 0, 
len(c.Samples[j].Lines)) - for _, line := range c.Samples[j].Lines { - if _, ok := line["_raw"]; ok { - if len(line["_raw"]) > 0 { - temp = append(temp, line["_raw"]) - } - } - } - s.Tokens[i].Choice = temp - } - break - } - } - } - - // Begin Validation logic - if s.EarliestParsed > s.LatestParsed { - log.Errorf("Earliest time cannot be greater than latest for sample '%s', disabling Sample", s.Name) - s.Disabled = true - return - } - // If no interval is set, generate one time and exit - if s.Interval == 0 && s.Generator != "replay" { - log.Infof("No interval set for sample '%s', setting endIntervals to 1", s.Name) - s.EndIntervals = 1 - } - for i, t := range s.Tokens { - switch t.Type { - case "random", "rated": - if t.Replacement == "int" || t.Replacement == "float" { - if t.Lower > t.Upper { - log.Errorf("Lower cannot be greater than Upper for token '%s' in sample '%s', disabling Sample", t.Name, s.Name) - s.Disabled = true - } else if t.Upper == 0 { - log.Errorf("Upper cannot be zero for token '%s' in sample '%s', disabling Sample", t.Name, s.Name) - s.Disabled = true - } - } else if t.Replacement == "string" || t.Replacement == "hex" { - if t.Length == 0 { - log.Errorf("Length cannot be zero for token '%s' in sample '%s', disabling Sample", t.Name, s.Name) - s.Disabled = true - } - } else { - if t.Replacement != "guid" && t.Replacement != "ipv4" && t.Replacement != "ipv6" { - log.Errorf("Replacement '%s' is invalid for token '%s' in sample '%s'", t.Replacement, t.Name, s.Name) - s.Disabled = true - } - } - case "choice": - if len(t.Choice) == 0 || t.Choice == nil { - log.Errorf("Zero choice items for token '%s' in sample '%s', disabling Sample", t.Name, s.Name) - s.Disabled = true - } - case "weightedChoice": - if len(t.WeightedChoice) == 0 || t.WeightedChoice == nil { - log.Errorf("Zero choice items for token '%s' in sample '%s', disabling Sample", t.Name, s.Name) - s.Disabled = true - } - case "fieldChoice": - if len(t.FieldChoice) == 0 || t.FieldChoice == 
nil { - log.Errorf("Zero choice items for token '%s' in sample '%s', disabling Sample", t.Name, s.Name) - s.Disabled = true - } - for _, choice := range t.FieldChoice { - if _, ok := choice[t.SrcField]; !ok { - log.Errorf("Source field '%s' does not exist for token '%s' in row '%#v' in sample '%s', disabling Sample", t.SrcField, t.Name, choice, s.Name) - s.Disabled = true - break - } - } - case "script": - s.Tokens[i].mutex = &sync.Mutex{} - for k, v := range t.Init { - vAsNum, err := strconv.ParseFloat(v, 64) - if err != nil { - t.luaState.RawSet(lua.LString(k), lua.LNumber(vAsNum)) - } else { - t.luaState.RawSet(lua.LString(k), lua.LString(v)) - } - } - } - } - - // Check if we are able to do singlepass on this sample by looping through all lines - // and ensuring we can match all the tokens on each line - if !s.Disabled { - s.SinglePass = true - - var tlines []map[string]tokenspos - - outer: - for _, l := range s.Lines { - tp := make(map[string]tokenspos) - for j, t := range s.Tokens { - // tokenpos 0 first char, 1 last char, 2 token # - var pos tokenpos - var err error - offsets, err := t.GetReplacementOffsets(l[t.Field]) - if err != nil || len(offsets) == 0 { - log.Infof("Error getting replacements for token '%s' in event '%s', disabling SinglePass", t.Name, l[t.Field]) - s.SinglePass = false - break outer - } - for _, offset := range offsets { - pos1 := offset[0] - pos2 := offset[1] - if pos1 < 0 || pos2 < 0 { - log.Infof("Token '%s' not found in event '%s', disabling SinglePass", t.Name, l) - s.SinglePass = false - break outer - } - pos.Pos1 = pos1 - pos.Pos2 = pos2 - pos.Token = j - tp[t.Field] = append(tp[t.Field], pos) - } - } - - // Ensure we don't have any tokens overlapping one another for singlepass - for _, v := range tp { - sort.Sort(v) - - lastpos := 0 - lasttoken := "" - maxpos := 0 - for _, pos := range v { - // Does the beginning of this token overlap with the end of the last? 
- if lastpos > pos.Pos1 { - log.Infof("Token '%s' extends beyond beginning of token '%s', disabling SinglePass", lasttoken, s.Tokens[pos.Token].Name) - s.SinglePass = false - break outer - } - // Does the beginning of this token happen before the max we've seen a token before? - if maxpos > pos.Pos1 { - log.Infof("Some former token extends beyond the beginning of token '%s', disabling SinglePass", s.Tokens[pos.Token].Name) - s.SinglePass = false - break outer - } - if pos.Pos2 > maxpos { - maxpos = pos.Pos2 - } - lastpos = pos.Pos2 - lasttoken = s.Tokens[pos.Token].Name - } - } - tlines = append(tlines, tp) - } - - if s.SinglePass { - - // Now loop through each line and each field, breaking it up according to the positions of the tokens - for i, line := range s.Lines { - if len(tlines) >= i && len(tlines) > 0 { - bline := make(map[string][]StringOrToken) - for field := range line { - var bfield []StringOrToken - // Field doesn't exist because no tokens hit that field - if _, ok := tlines[i][field]; !ok { - bf := StringOrToken{T: nil, S: line[field]} - bfield = append(bfield, bf) - } else { - lastpos := 0 - // Here, we need to iterate through all the tokens and add StringOrToken for each match - // Make sure we check for a token a pos 0, we'll put a token first - for _, tp := range tlines[i][field] { - if tp.Pos1 == 0 { - bf := StringOrToken{T: &s.Tokens[tp.Token], S: ""} - bfield = append(bfield, bf) - lastpos = tp.Pos2 - } else { - // Add string from end of last token to the beginning of this one - bf := StringOrToken{T: nil, S: s.Lines[i][field][lastpos:tp.Pos1]} - bfield = append(bfield, bf) - // Add this token - bf = StringOrToken{T: &s.Tokens[tp.Token], S: ""} - bfield = append(bfield, bf) - lastpos = tp.Pos2 - } - } - // Add the last string if the last token didn't cover to the end of the string - if lastpos < len(s.Lines[i][field]) { - bf := StringOrToken{T: nil, S: s.Lines[i][field][lastpos:]} - bfield = append(bfield, bf) - } - } - bline[field] = bfield - 
} - s.BrokenLines = append(s.BrokenLines, bline) - } - } - } - } - - if s.Generator == "replay" { - // For replay, loop through all events, attempt to find a timestamp in each row, store sleep times in a data structure - s.ReplayOffsets = make([]time.Duration, len(s.Lines)) - var lastts time.Time - var avgOffset time.Duration - outer2: - for i := 0; i < len(s.Lines); i++ { - inner2: - for _, t := range s.Tokens { - if t.Type == "timestamp" || t.Type == "gotimestamp" || t.Type == "epochtimestamp" { - offsets, err := t.GetReplacementOffsets(s.Lines[i][t.Field]) - if err != nil || len(offsets) == 0 { - log.WithFields(log.Fields{ - "token": t.Name, - "sample": s.Name, - "err": err, - }).Errorf("Error getting timestamp offsets, disabling sample") - s.Disabled = true - break outer2 - } - pos1 := offsets[0][0] - pos2 := offsets[0][1] - ts, err := t.ParseTimestamp(s.Lines[i][t.Field][pos1:pos2]) - if err != nil { - log.WithFields(log.Fields{ - "token": t.Name, - "sample": s.Name, - "err": err, - "event": s.Lines[0][t.Field], - }).Errorf("Error parsing timestamp, disabling sample") - s.Disabled = true - break outer2 - } - if i > 0 { - s.ReplayOffsets[i-1] = lastts.Sub(ts) * -1 - avgOffset = (avgOffset + s.ReplayOffsets[i-1]) / 2 - } - lastts = ts - break inner2 - } - } - s.ReplayOffsets[len(s.ReplayOffsets)-1] = avgOffset - } - log.WithFields(log.Fields{ - "sample": s.Name, - "ReplayOffsets": s.ReplayOffsets, - }).Debugf("ReplayOffsets values") - } else if s.Generator != "sample" { - for _, g := range c.Generators { - // TODO If not single threaded, we won't establish state in the sample object - if g.Name == s.Generator { - s.LuaMutex = &sync.Mutex{} - s.CustomGenerator = g - if g.SingleThreaded { - s.GeneratorState = NewGeneratorState(s) - } - } - } - if s.CustomGenerator == nil { - log.Errorf("Generator '%s' not found for sample '%s', disabling sample", s.Generator, s.Name) - s.Disabled = true - } - } - } -} - -// Returns a copy of the rater with the Options properly 
cast -func (c *Config) validateRater(r *RaterConfig) { - configRaterKeys := map[string]bool{ - "HourOfDay": true, - "MinuteOfHour": true, - "DayOfWeek": true, - } - - opt := make(map[string]interface{}) - for k, v := range r.Options { - var newvset interface{} - if configRaterKeys[k] { - newv := make(map[int]float64) - vcast := v.(map[interface{}]interface{}) - for k2, v2 := range vcast { - k2int := k2.(int) - v2float, ok := v2.(float64) - if !ok { - v2int, ok := v2.(int) - if !ok { - log.Fatalf("Rater value '%#v' of key '%#v' for rater '%s' in '%s' is not a float or int", v2, k2, r.Name, k) - } - v2float = float64(v2int) - } - newv[k2int] = v2float - } - newvset = newv - } else { - newvset = v - } - opt[k] = newvset - } - r.Options = opt -} - // Brings in a Generator script from a file func (c *Config) readGenerator(configDir string, g *GeneratorConfig) error { // First try to find the file by absolute path @@ -968,7 +493,7 @@ func (c *Config) readGenerator(configDir string, g *GeneratorConfig) error { } else if err != nil { return err } - contents, err := ioutil.ReadFile(fullPath) + contents, err := os.ReadFile(fullPath) if err != nil { return err } @@ -1030,211 +555,6 @@ func ParseBeginEnd(s *Sample) { log.Infof("Beginning generation at %s; Ending at %s; Realtime: %v", s.BeginParsed, s.EndParsed, s.Realtime) } -// SetupSystemTokens adds tokens like time and facility to samples based on configuration -func (c *Config) SetupSystemTokens() { - addToken := func(s *Sample, tokenName string, tokenType string, tokenReplacement string) { - // If there's no _time token, add it to make sure we have a timestamp field in every event - tokenfound := false - for _, t := range s.Tokens { - if t.Name == tokenName { - tokenfound = true - } - } - for _, l := range s.Lines { - if _, ok := l[tokenName]; ok { - tokenfound = true - } - } - if !tokenfound { - log.Infof("Adding %s token for sample %s", tokenName, s.Name) - tt := Token{ - Name: tokenName, - Type: tokenType, - Format: 
"template", - Field: tokenName, - Token: fmt.Sprintf("$%s$", tokenName), - Group: -1, - Parent: s, - } - if tokenReplacement != "" { - tt.Replacement = tokenReplacement - } - s.Tokens = append(s.Tokens, tt) - if s.SinglePass { - for j := 0; j < len(s.BrokenLines); j++ { - st := []StringOrToken{ - StringOrToken{T: &tt, S: ""}, - } - s.BrokenLines[j][tokenName] = st - } - } - for j := 0; j < len(s.Lines); j++ { - s.Lines[j][tokenName] = fmt.Sprintf("$%s$", tokenName) - } - } - } - addField := func(s *Sample, name string, value string) { - log.Infof("Adding %s field for sample %s", name, s.Name) - for i := 0; i < len(s.Lines); i++ { - if s.Lines[i][name] == "" { - s.Lines[i][name] = value - } - } - if s.SinglePass { - for i := 0; i < len(s.BrokenLines); i++ { - st := []StringOrToken{ - StringOrToken{T: nil, S: value}, - } - if _, ok := s.BrokenLines[i][name]; !ok { - s.BrokenLines[i][name] = st - } - } - } - } - syslogOutput := c.Global.Output.OutputTemplate == "rfc3164" || c.Global.Output.OutputTemplate == "rfc5424" - addTime := c.Global.Output.OutputTemplate == "splunkhec" || - c.Global.Output.OutputTemplate == "elasticsearch" || - c.Global.AddTime || - syslogOutput - if !c.cc.Export && addTime { - // Use epochtimestamp for Splunk, or different formats for rfc3164 or rfc5424 - var tokenType string - var tokenReplacement string - tokenName := "_time" - if c.Global.Output.OutputTemplate == "elasticsearch" { - tokenName = "@timestamp" - tokenType = "gotimestamp" - tokenReplacement = "2006-01-02T15:04:05.999Z07:00" - } else if !syslogOutput { - tokenType = "epochtimestamp" - } else if c.Global.Output.OutputTemplate == "rfc3164" { - tokenType = "gotimestamp" - tokenReplacement = "Jan _2 15:04:05" - } else if c.Global.Output.OutputTemplate == "rfc5424" { - tokenType = "gotimestamp" - tokenReplacement = "2006-01-02T15:04:05.999999Z07:00" - } - for i := 0; i < len(c.Samples); i++ { - s := c.Samples[i] - addToken(s, tokenName, tokenType, tokenReplacement) // Timestamp - // 
Add fields for syslog output - if syslogOutput { - addField(s, "priority", fmt.Sprintf("%d", defaultSyslogPriority)) - hostname, _ := os.Hostname() - addField(s, "host", hostname) - tag := "gogen" - if len(s.Lines) > 0 && s.Lines[0]["sourcetype"] != "" { - tag = s.Lines[0]["sourcetype"] - } - addField(s, "tag", tag) - addField(s, "pid", fmt.Sprintf("%d", os.Getpid())) - addField(s, "appName", "gogen") - } - // Fixup existing timestamp tokens to all use the same static group, -1 - for j := 0; j < len(s.Tokens); j++ { - if s.Tokens[j].Type == "timestamp" || s.Tokens[j].Type == "gotimestamp" || s.Tokens[j].Type == "epochtimestamp" { - s.Tokens[j].Group = -1 - } - } - } - } -} - -func (c *Config) walkPath(fullPath string, acceptableExtensions map[string]bool, callback func(string) error) error { - log.Debugf("walkPath '%s' for extensions: '%v'", fullPath, acceptableExtensions) - fullPath = os.ExpandEnv(fullPath) - info, err := os.Stat(fullPath) - if err != nil { - return err - } - if info.IsDir() { - fullPath += string(filepath.Separator) - } - // filepath.Walk(os.ExpandEnv(fullPath), func(path string, _ os.FileInfo, err error) error { - // log.Debugf("Walking, at %s", path) - // if os.IsNotExist(err) { - // return nil - // } else if err != nil { - // log.Errorf("Error from WalkFunc: %s", err) - // return err - // } - // // Check if extension is acceptable before attempting to parse - // if acceptableExtensions[filepath.Ext(path)] { - // return callback(path) - // } - // return nil - // }) - files, err := filepath.Glob(fullPath + "*") - if err != nil { - return err - } - for _, path := range files { - // log.Debugf("Walking, at %s", path) - if acceptableExtensions[filepath.Ext(path)] { - err := callback(path) - if err != nil { - return err - } - } - } - return nil -} - -func (c *Config) parseFileConfig(out interface{}, path ...string) error { - fullPath := filepath.Join(path...) 
- log.Debugf("Config Path: %v", fullPath) - if _, err := os.Stat(fullPath); os.IsNotExist(err) { - return err - } - - contents, err := ioutil.ReadFile(fullPath) - if err != nil { - return err - } - - // log.Debugf("Contents: %s", contents) - switch filepath.Ext(fullPath) { - case ".yml", ".yaml": - if err := yaml.Unmarshal(contents, out); err != nil { - if ute, ok := err.(*json.UnmarshalTypeError); ok { - log.Errorf("JSON parsing error in file '%s' at offset %d: %v", fullPath, ute.Offset, ute) - } else { - log.Errorf("YAML parsing error in file '%s': %v", fullPath, err) - } - } - case ".json": - if err := json.Unmarshal(contents, out); err != nil { - if ute, ok := err.(*json.UnmarshalTypeError); ok { - log.Errorf("JSON parsing error in file '%s' at offset %d: %v", fullPath, ute.Offset, ute) - } else { - log.Errorf("JSON parsing error in file '%s': %v", fullPath, err) - } - } - } - // log.Debugf("Out: %#v\n", out) - return nil -} - -func (c *Config) parseWebConfig(out interface{}, url string) error { - resp, err := http.Get(url) - if err != nil { - return err - } - contents, err := ioutil.ReadAll(resp.Body) - if err != nil { - return err - } - // Try YAML then JSON - err = yaml.Unmarshal(contents, out) - if err != nil { - err = json.Unmarshal(contents, out) - if err != nil { - return err - } - } - return nil -} - // FindSampleByName finds and returns a pointer to a sample referenced by the passed name func (c Config) FindSampleByName(name string) *Sample { for i := 0; i < len(c.Samples); i++ { @@ -1245,7 +565,7 @@ func (c Config) FindSampleByName(name string) *Sample { return nil } -// covertUTC sets time local to UTC if configured as UTC +// convertUTC sets time local to UTC if configured as UTC func convertUTC(t time.Time) time.Time { if instance != nil { if instance.Global.UTC { @@ -1268,7 +588,7 @@ func (c *Config) Clean() { debug.FreeOSMemory() } -// WriteFileFromString writes a configuration string to a temporary file and returns the filename +// 
WriteTempConfigFileFromString writes a configuration string to a temporary file and returns the filename func WriteTempConfigFileFromString(config string) string { tmpfile, err := os.CreateTemp("", "gogen-test-*.yml") if err != nil { diff --git a/internal/config_parse.go b/internal/config_parse.go new file mode 100644 index 0000000..f8ce25f --- /dev/null +++ b/internal/config_parse.go @@ -0,0 +1,95 @@ +package internal + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + + log "github.com/coccyx/gogen/logger" + yaml "gopkg.in/yaml.v2" +) + +func (c *Config) parseFileConfig(out interface{}, path ...string) error { + fullPath := filepath.Join(path...) + log.Debugf("Config Path: %v", fullPath) + + contents, err := os.ReadFile(fullPath) + if err != nil { + return err + } + + var parseErr error + switch filepath.Ext(fullPath) { + case ".yml", ".yaml": + parseErr = yaml.Unmarshal(contents, out) + case ".json": + parseErr = json.Unmarshal(contents, out) + } + if parseErr != nil { + return fmt.Errorf("parsing error in file '%s': %w", fullPath, parseErr) + } + return nil +} + +func (c *Config) parseWebConfig(out interface{}, url string) error { + resp, err := http.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + contents, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + // Try YAML then JSON + err = yaml.Unmarshal(contents, out) + if err != nil { + err = json.Unmarshal(contents, out) + if err != nil { + return err + } + } + return nil +} + +func (c *Config) walkPath(fullPath string, acceptableExtensions map[string]bool, callback func(string) error) error { + log.Debugf("walkPath '%s' for extensions: '%v'", fullPath, acceptableExtensions) + fullPath = os.ExpandEnv(fullPath) + info, err := os.Stat(fullPath) + if err != nil { + return err + } + if info.IsDir() { + fullPath += string(filepath.Separator) + } + files, err := filepath.Glob(fullPath + "*") + if err != nil { + return err + } + for _, path := 
range files { + if acceptableExtensions[filepath.Ext(path)] { + err := callback(path) + if err != nil { + return err + } + } + } + return nil +} + +// loadConfigDir reads all config files from configDir/subDir and appends parsed items to dest. +func loadConfigDir[T any](c *Config, configDir, subDir string, dest *[]*T) { + fullPath := filepath.Join(configDir, subDir) + c.walkPath(fullPath, configExtensions, func(innerPath string) error { + item := new(T) + if err := c.parseFileConfig(item, innerPath); err != nil { + log.Errorf("Error parsing config %s: %s", innerPath, err) + return err + } + *dest = append(*dest, item) + return nil + }) +} diff --git a/internal/config_system_tokens.go b/internal/config_system_tokens.go new file mode 100644 index 0000000..4c84f3b --- /dev/null +++ b/internal/config_system_tokens.go @@ -0,0 +1,122 @@ +package internal + +import ( + "fmt" + "os" + + log "github.com/coccyx/gogen/logger" +) + +// SetupSystemTokens adds tokens like time and facility to samples based on configuration +func (c *Config) SetupSystemTokens() { + addToken := func(s *Sample, tokenName string, tokenType string, tokenReplacement string) { + // If there's no _time token, add it to make sure we have a timestamp field in every event + tokenfound := false + for _, t := range s.Tokens { + if t.Name == tokenName { + tokenfound = true + break + } + } + if !tokenfound { + for _, l := range s.Lines { + if _, ok := l[tokenName]; ok { + tokenfound = true + break + } + } + } + if !tokenfound { + log.Infof("Adding %s token for sample %s", tokenName, s.Name) + tt := Token{ + Name: tokenName, + Type: tokenType, + Format: "template", + Field: tokenName, + Token: fmt.Sprintf("$%s$", tokenName), + Group: -1, + Parent: s, + } + if tokenReplacement != "" { + tt.Replacement = tokenReplacement + } + s.Tokens = append(s.Tokens, tt) + if s.SinglePass { + for j := 0; j < len(s.BrokenLines); j++ { + st := []StringOrToken{ + {T: &tt, S: ""}, + } + s.BrokenLines[j][tokenName] = st + } + } + 
for j := 0; j < len(s.Lines); j++ { + s.Lines[j][tokenName] = fmt.Sprintf("$%s$", tokenName) + } + } + } + addField := func(s *Sample, name string, value string) { + log.Infof("Adding %s field for sample %s", name, s.Name) + for i := 0; i < len(s.Lines); i++ { + if s.Lines[i][name] == "" { + s.Lines[i][name] = value + } + } + if s.SinglePass { + for i := 0; i < len(s.BrokenLines); i++ { + st := []StringOrToken{ + {T: nil, S: value}, + } + if _, ok := s.BrokenLines[i][name]; !ok { + s.BrokenLines[i][name] = st + } + } + } + } + syslogOutput := c.Global.Output.OutputTemplate == "rfc3164" || c.Global.Output.OutputTemplate == "rfc5424" + addTime := c.Global.Output.OutputTemplate == "splunkhec" || + c.Global.Output.OutputTemplate == "elasticsearch" || + c.Global.AddTime || + syslogOutput + if !c.cc.Export && addTime { + // Use epochtimestamp for Splunk, or different formats for rfc3164 or rfc5424 + var tokenType string + var tokenReplacement string + tokenName := "_time" + if c.Global.Output.OutputTemplate == "elasticsearch" { + tokenName = "@timestamp" + tokenType = "gotimestamp" + tokenReplacement = "2006-01-02T15:04:05.999Z07:00" + } else if !syslogOutput { + tokenType = "epochtimestamp" + } else if c.Global.Output.OutputTemplate == "rfc3164" { + tokenType = "gotimestamp" + tokenReplacement = "Jan _2 15:04:05" + } else if c.Global.Output.OutputTemplate == "rfc5424" { + tokenType = "gotimestamp" + tokenReplacement = "2006-01-02T15:04:05.999999Z07:00" + } + hostname, _ := os.Hostname() + for i := 0; i < len(c.Samples); i++ { + s := c.Samples[i] + addToken(s, tokenName, tokenType, tokenReplacement) // Timestamp + // Add fields for syslog output + if syslogOutput { + addField(s, "priority", fmt.Sprintf("%d", defaultSyslogPriority)) + addField(s, "host", hostname) + tag := "gogen" + if len(s.Lines) > 0 && s.Lines[0]["sourcetype"] != "" { + tag = s.Lines[0]["sourcetype"] + } + addField(s, "tag", tag) + addField(s, "pid", fmt.Sprintf("%d", os.Getpid())) + addField(s, 
"appName", "gogen") + } + // Fixup existing timestamp tokens to all use the same static group, -1 + for j := 0; j < len(s.Tokens); j++ { + if s.Tokens[j].Type == "timestamp" || s.Tokens[j].Type == "gotimestamp" || s.Tokens[j].Type == "epochtimestamp" { + s.Tokens[j].Group = -1 + } + } + } + } +} diff --git a/internal/config_test.go b/internal/config_test.go index 8e565af..205128d 100644 --- a/internal/config_test.go +++ b/internal/config_test.go @@ -2,6 +2,8 @@ package internal import ( "math/rand" + "net/http" + "net/http/httptest" "os" "path/filepath" "reflect" @@ -356,3 +358,1296 @@ func TestParseWebConfig(t *testing.T) { assert.Equal(t, "timestamp", tsToken.Type) assert.Equal(t, "%d/%b/%Y %H:%M:%S:%L", tsToken.Replacement) } + +func TestFindRater(t *testing.T) { + os.Setenv("GOGEN_HOME", "..") + os.Setenv("GOGEN_ALWAYS_REFRESH", "1") + os.Setenv("GOGEN_FULLCONFIG", filepath.Join("..", "tests", "rater", "configrater.yml")) + + c := NewConfig() + + r := c.FindRater("testconfigrater") + assert.NotNil(t, r) + assert.Equal(t, "testconfigrater", r.Name) + + CleanupConfigAndEnvironment() +} + +func TestFindRaterNotFound(t *testing.T) { + os.Setenv("GOGEN_HOME", "..") + os.Setenv("GOGEN_ALWAYS_REFRESH", "1") + os.Setenv("GOGEN_FULLCONFIG", filepath.Join("..", "tests", "rater", "configrater.yml")) + + c := NewConfig() + + r := c.FindRater("nonexistentrater") + assert.Nil(t, r) + + CleanupConfigAndEnvironment() +} + +func TestFindSampleByNameNotFound(t *testing.T) { + os.Setenv("GOGEN_HOME", "..") + os.Setenv("GOGEN_ALWAYS_REFRESH", "1") + os.Setenv("GOGEN_FULLCONFIG", filepath.Join("..", "tests", "rater", "configrater.yml")) + + c := NewConfig() + + s := c.FindSampleByName("nonexistentsample") + assert.Nil(t, s) + + CleanupConfigAndEnvironment() +} + +func TestClean(t *testing.T) { + configStr := ` +samples: + - name: enabled-sample + interval: 1 + count: 1 + endIntervals: 1 + lines: + - _raw: test + - name: disabled-sample + disabled: true + interval: 1 + count: 1 + 
lines: + - _raw: test +` + ResetConfig() + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + + // After Clean(), only enabled real samples should remain + found := false + for _, s := range c.Samples { + if s.Name == "disabled-sample" { + found = true + } + } + assert.False(t, found, "disabled sample should be removed by Clean()") + + foundEnabled := false + for _, s := range c.Samples { + if s.Name == "enabled-sample" { + foundEnabled = true + } + } + assert.True(t, foundEnabled, "enabled sample should remain after Clean()") +} + +func TestParseBeginEndWithEndIntervals(t *testing.T) { + s := &Sample{ + Name: "test", + EndIntervals: 3, + Interval: 5, + } + + ParseBeginEnd(s) + + assert.Equal(t, "-15s", s.Begin) + assert.Equal(t, "now", s.End) + assert.False(t, s.Realtime) + assert.False(t, s.BeginParsed.IsZero()) + assert.False(t, s.EndParsed.IsZero()) +} + +func TestParseBeginEndEmptyEnd(t *testing.T) { + s := &Sample{ + Name: "test", + End: "", + } + + ParseBeginEnd(s) + + // Empty end means realtime + assert.True(t, s.Realtime) + assert.True(t, s.EndParsed.IsZero()) +} + +func TestParseBeginEndBeginOverridesRealtime(t *testing.T) { + s := &Sample{ + Name: "test", + Begin: "-60s", + End: "", + } + + ParseBeginEnd(s) + + // Begin set without endIntervals: sets Realtime to false via parsing + assert.False(t, s.Realtime) + assert.False(t, s.BeginParsed.IsZero()) +} + +func TestSetupFromFile(t *testing.T) { + SetupFromFile("/tmp/testfile.yml") + defer CleanupConfigAndEnvironment() + + assert.Equal(t, "..", os.Getenv("GOGEN_HOME")) + assert.Equal(t, "1", os.Getenv("GOGEN_ALWAYS_REFRESH")) + assert.Equal(t, "/tmp/testfile.yml", os.Getenv("GOGEN_FULLCONFIG")) +} + +func TestSetupSystemTokensSplunkHEC(t *testing.T) { + ResetConfig() + + configStr := ` +global: + output: + outputter: stdout + outputTemplate: splunkhec +samples: + - name: hectokensample + interval: 1 + count: 1 + endIntervals: 1 + lines: + - _raw: test event +` + 
SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + s := c.FindSampleByName("hectokensample") + assert.NotNil(t, s) + + // Should have a _time token added by SetupSystemTokens + foundTime := false + for _, tk := range s.Tokens { + if tk.Name == "_time" { + foundTime = true + assert.Equal(t, "epochtimestamp", tk.Type) + } + } + assert.True(t, foundTime, "splunkhec should add _time token") +} + +func TestSetupSystemTokensElasticsearch(t *testing.T) { + ResetConfig() + + configStr := ` +global: + output: + outputter: stdout + outputTemplate: elasticsearch +samples: + - name: estokensample + interval: 1 + count: 1 + endIntervals: 1 + lines: + - _raw: test event +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + s := c.FindSampleByName("estokensample") + assert.NotNil(t, s) + + foundTimestamp := false + for _, tk := range s.Tokens { + if tk.Name == "@timestamp" { + foundTimestamp = true + assert.Equal(t, "gotimestamp", tk.Type) + } + } + assert.True(t, foundTimestamp, "elasticsearch should add @timestamp token") +} + +func TestSetupSystemTokensRFC3164(t *testing.T) { + ResetConfig() + + configStr := ` +global: + output: + outputter: stdout + outputTemplate: rfc3164 +samples: + - name: rfc3164sample + interval: 1 + count: 1 + endIntervals: 1 + lines: + - _raw: syslog event +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + s := c.FindSampleByName("rfc3164sample") + assert.NotNil(t, s) + + foundTime := false + foundPriority := false + foundHost := false + foundTag := false + foundPid := false + for _, tk := range s.Tokens { + if tk.Name == "_time" { + foundTime = true + assert.Equal(t, "gotimestamp", tk.Type) + } + } + assert.True(t, foundTime, "rfc3164 should add _time token") + + // Check that syslog fields were added to lines + if len(s.Lines) > 0 { + if _, ok := s.Lines[0]["priority"]; ok { + foundPriority = true + } + if _, ok := s.Lines[0]["host"]; 
ok { + foundHost = true + } + if _, ok := s.Lines[0]["tag"]; ok { + foundTag = true + } + if _, ok := s.Lines[0]["pid"]; ok { + foundPid = true + } + } + assert.True(t, foundPriority, "rfc3164 should add priority field") + assert.True(t, foundHost, "rfc3164 should add host field") + assert.True(t, foundTag, "rfc3164 should add tag field") + assert.True(t, foundPid, "rfc3164 should add pid field") +} + +func TestSetupSystemTokensRFC5424(t *testing.T) { + ResetConfig() + + configStr := ` +global: + output: + outputter: stdout + outputTemplate: rfc5424 +samples: + - name: rfc5424sample + interval: 1 + count: 1 + endIntervals: 1 + lines: + - _raw: syslog5424 event +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + s := c.FindSampleByName("rfc5424sample") + assert.NotNil(t, s) + + foundTime := false + for _, tk := range s.Tokens { + if tk.Name == "_time" { + foundTime = true + assert.Equal(t, "gotimestamp", tk.Type) + } + } + assert.True(t, foundTime, "rfc5424 should add _time token") + + if len(s.Lines) > 0 { + _, hasAppName := s.Lines[0]["appName"] + assert.True(t, hasAppName, "rfc5424 should add appName field") + } +} + +func TestBuildConfigDefaults(t *testing.T) { + ResetConfig() + + configStr := ` +samples: + - name: defaultsample + interval: 1 + count: 1 + endIntervals: 1 + lines: + - _raw: test +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + + // Check defaults were applied + assert.Equal(t, 1, c.Global.GeneratorWorkers) + assert.Equal(t, 1, c.Global.OutputWorkers) + assert.Equal(t, "stdout", c.Global.Output.Outputter) + assert.Equal(t, "raw", c.Global.Output.OutputTemplate) + assert.Equal(t, 5, c.Global.Output.BackupFiles) + assert.NotZero(t, c.Global.Output.MaxBytes) + assert.NotZero(t, c.Global.Output.BufferBytes) + assert.NotZero(t, c.Global.Output.Timeout) +} + +func TestValidateDisabledNoLines(t *testing.T) { + ResetConfig() + + configStr := ` +samples: + - name: nolines 
+ interval: 1 + count: 1 + endIntervals: 1 +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + // Sample with no lines should be disabled and cleaned away + s := c.FindSampleByName("nolines") + assert.Nil(t, s, "sample with no lines should be removed") +} + +func TestConvertUTC(t *testing.T) { + ResetConfig() + + configStr := ` +global: + utc: true +samples: + - name: utctest + interval: 1 + count: 1 + endIntervals: 1 + lines: + - _raw: test +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + _ = NewConfig() + + now := time.Now() + utcTime := convertUTC(now) + assert.Equal(t, now.UTC(), utcTime) +} + +func TestSampleNow(t *testing.T) { + s := &Sample{ + Realtime: true, + } + beforeCall := time.Now() + result := s.Now() + afterCall := time.Now() + assert.True(t, !result.Before(beforeCall) && !result.After(afterCall), + "Realtime Now() should return current time") + + fixedTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + s.Realtime = false + s.Current = fixedTime + result = s.Now() + assert.Equal(t, fixedTime, result) +} + +func TestReadSamplesDir(t *testing.T) { + ResetConfig() + + // Use the existing test samples directory + os.Setenv("GOGEN_HOME", "..") + os.Setenv("GOGEN_ALWAYS_REFRESH", "1") + os.Setenv("GOGEN_FULLCONFIG", "") + os.Setenv("GOGEN_SAMPLES_DIR", filepath.Join("..", "tests", "tokens")) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + // Should have loaded samples from the tokens test directory + assert.NotEmpty(t, c.Samples, "should load samples from samples dir") +} + +func TestParseFileConfigJSON(t *testing.T) { + ResetConfig() + + // Create a JSON config file with the Config struct format + dir := t.TempDir() + jsonFile := filepath.Join(dir, "test.json") + jsonContent := `{"samples": [{"name": "jsonsample", "interval": 1, "count": 1, "endIntervals": 1, "lines": [{"_raw": "json test"}]}]}` + os.WriteFile(jsonFile, []byte(jsonContent), 0644) + + os.Setenv("GOGEN_HOME", 
"..") + os.Setenv("GOGEN_ALWAYS_REFRESH", "1") + os.Setenv("GOGEN_FULLCONFIG", jsonFile) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + assert.NotEmpty(t, c.Samples, "should load samples from JSON config") +} + +func TestNegativeCacheIntervals(t *testing.T) { + ResetConfig() + + configStr := ` +global: + cacheIntervals: -5 +samples: + - name: cachesample + interval: 1 + count: 1 + endIntervals: 1 + lines: + - _raw: test +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + assert.Equal(t, 0, c.Global.CacheIntervals, "negative cacheIntervals should be clamped to 0") +} + +func TestReadSamplesDirSampleFile(t *testing.T) { + ResetConfig() + + dir := t.TempDir() + + // Create a .sample file + err := os.WriteFile(filepath.Join(dir, "test.sample"), []byte("line one\nline two\nline three\n"), 0644) + assert.NoError(t, err) + + c := &Config{cc: ConfigConfig{}} + c.readSamplesDir(dir) + + // Should have loaded one sample with 3 lines + found := false + for _, s := range c.Samples { + if s.Name == "test.sample" { + found = true + assert.True(t, s.Disabled, ".sample files should be disabled by default") + assert.Equal(t, 3, len(s.Lines)) + assert.Equal(t, "line one", s.Lines[0]["_raw"]) + assert.Equal(t, "line two", s.Lines[1]["_raw"]) + assert.Equal(t, "line three", s.Lines[2]["_raw"]) + } + } + assert.True(t, found, "should find test.sample") +} + +func TestReadSamplesDirCSVFile(t *testing.T) { + ResetConfig() + + dir := t.TempDir() + + // Create a .csv file with header + csvContent := "name,city,state\nalice,NYC,NY\nbob,LA,CA\n" + err := os.WriteFile(filepath.Join(dir, "test.csv"), []byte(csvContent), 0644) + assert.NoError(t, err) + + c := &Config{cc: ConfigConfig{}} + c.readSamplesDir(dir) + + found := false + for _, s := range c.Samples { + if s.Name == "test.csv" { + found = true + assert.True(t, s.Disabled, ".csv files should be disabled by default") + assert.Equal(t, 2, len(s.Lines)) + assert.Equal(t, "alice", 
s.Lines[0]["name"]) + assert.Equal(t, "NYC", s.Lines[0]["city"]) + assert.Equal(t, "NY", s.Lines[0]["state"]) + assert.Equal(t, "bob", s.Lines[1]["name"]) + } + } + assert.True(t, found, "should find test.csv") +} + +func TestReadSamplesDirYAMLFile(t *testing.T) { + ResetConfig() + + dir := t.TempDir() + + yamlContent := `name: yamlsample +interval: 1 +count: 1 +lines: + - _raw: yaml test line +` + err := os.WriteFile(filepath.Join(dir, "yamltest.yml"), []byte(yamlContent), 0644) + assert.NoError(t, err) + + c := &Config{cc: ConfigConfig{}} + c.readSamplesDir(dir) + + found := false + for _, s := range c.Samples { + if s.Name == "yamlsample" { + found = true + assert.True(t, s.realSample) + } + } + assert.True(t, found, "should find yamlsample") +} + +func TestReadSamplesDirEmptyDir(t *testing.T) { + ResetConfig() + + dir := t.TempDir() + + c := &Config{cc: ConfigConfig{}} + c.readSamplesDir(dir) + + // Should not crash and should have no samples + assert.Empty(t, c.Samples) +} + +func TestReadGeneratorFallbackPath(t *testing.T) { + ResetConfig() + + dir := t.TempDir() + genDir := filepath.Join(dir, "generators") + os.MkdirAll(genDir, 0755) + + // Create generator script in the fallback directory + scriptContent := `-- test generator\nsetToken("test", "value")\n` + err := os.WriteFile(filepath.Join(genDir, "testgen.lua"), []byte(scriptContent), 0644) + assert.NoError(t, err) + + c := &Config{cc: ConfigConfig{ConfigDir: dir}} + g := &GeneratorConfig{Name: "testgen", FileName: "testgen.lua"} + + err = c.readGenerator(dir, g) + assert.NoError(t, err) + assert.Contains(t, g.Script, "test generator") +} + +func TestReadGeneratorNotFound(t *testing.T) { + ResetConfig() + + dir := t.TempDir() + c := &Config{cc: ConfigConfig{ConfigDir: dir}} + g := &GeneratorConfig{Name: "missing", FileName: "nonexistent.lua"} + + err := c.readGenerator(dir, g) + assert.Error(t, err) + assert.Contains(t, err.Error(), "Cannot find generator file") +} + +func TestValidateTokenRandomString(t 
*testing.T) { + ResetConfig() + configStr := ` +samples: + - name: randstring + interval: 1 + count: 1 + endIntervals: 1 + tokens: + - name: rs + format: template + token: $rs$ + type: random + replacement: string + length: 10 + lines: + - _raw: $rs$ +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + s := c.FindSampleByName("randstring") + assert.NotNil(t, s, "sample with valid random string token should not be disabled") +} + +func TestValidateTokenRandomStringZeroLength(t *testing.T) { + ResetConfig() + configStr := ` +samples: + - name: randstringbad + interval: 1 + count: 1 + endIntervals: 1 + tokens: + - name: rs + format: template + token: $rs$ + type: random + replacement: string + length: 0 + lines: + - _raw: $rs$ +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + s := c.FindSampleByName("randstringbad") + assert.Nil(t, s, "sample with zero-length random string should be disabled") +} + +func TestValidateTokenReplacementTypes(t *testing.T) { + tests := []struct { + name string + replacement string + extra string + valid bool + }{ + {"hex", "hex", "length: 5", true}, + {"guid", "guid", "", true}, + {"ipv4", "ipv4", "", true}, + {"ipv6", "ipv6", "", true}, + {"invalid", "invalid_replacement_xyz", "", false}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ResetConfig() + extra := "" + if tc.extra != "" { + extra = "\n " + tc.extra + } + configStr := ` +samples: + - name: ` + tc.name + ` + interval: 1 + count: 1 + endIntervals: 1 + tokens: + - name: tk + format: template + token: $tk$ + type: random + replacement: ` + tc.replacement + extra + ` + lines: + - _raw: $tk$ +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + s := c.FindSampleByName(tc.name) + if tc.valid { + assert.NotNil(t, s, "%s should not be disabled", tc.name) + } else { + assert.Nil(t, s, "%s should be disabled", tc.name) + } + }) + } +} + +func 
TestValidateTokenScript(t *testing.T) { + ResetConfig() + configStr := ` +samples: + - name: scripttest + interval: 1 + count: 1 + endIntervals: 1 + tokens: + - name: sc + format: template + token: $sc$ + type: script + init: + myvar: "42" + scriptSrc: | + return "hello" + lines: + - _raw: $sc$ +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + s := c.FindSampleByName("scripttest") + assert.NotNil(t, s, "sample with script token should not be disabled") + // Check that script token has mutex + for _, tk := range s.Tokens { + if tk.Name == "sc" { + assert.NotNil(t, tk.mutex, "script token should have mutex initialized") + } + } +} + +func TestValidateNoInterval(t *testing.T) { + ResetConfig() + configStr := ` +samples: + - name: nointerval + count: 1 + lines: + - _raw: test +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + s := c.FindSampleByName("nointerval") + assert.NotNil(t, s) + assert.Equal(t, 1, s.EndIntervals, "no interval should auto-set endIntervals to 1") +} + +func TestValidateEmptyName(t *testing.T) { + ResetConfig() + + c := &Config{ + Global: Global{ + Output: Output{ + Outputter: "stdout", + OutputTemplate: "raw", + }, + }, + } + s := &Sample{ + realSample: true, + Name: "", + Lines: []map[string]string{{"_raw": "test"}}, + } + c.validate(s) + assert.True(t, s.Disabled, "sample with empty name should be disabled") +} + +func TestValidateRaterWithIntValues(t *testing.T) { + ResetConfig() + configStr := ` +raters: + - name: testrater + type: config + options: + HourOfDay: + 0: 1 + 12: 2 + DayOfWeek: + 0: 1.5 + 6: 0.5 +samples: + - name: ratertest + interval: 1 + count: 1 + endIntervals: 1 + rater: testrater + lines: + - _raw: test +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + r := c.FindRater("testrater") + assert.NotNil(t, r) + // HourOfDay should be converted to map[int]float64 + hod, ok := 
r.Options["HourOfDay"].(map[int]float64) + assert.True(t, ok, "HourOfDay should be map[int]float64") + assert.Equal(t, 1.0, hod[0]) + assert.Equal(t, 2.0, hod[12]) +} + +func TestValidateWeightedChoice(t *testing.T) { + ResetConfig() + configStr := ` +samples: + - name: weightsource + disabled: true + lines: + - value: alpha + _weight: "3" + - value: beta + _weight: "7" + - name: weightuser + interval: 1 + count: 1 + endIntervals: 1 + tokens: + - name: wt + format: template + token: $wt$ + type: weightedChoice + sample: weightsource + srcField: value + lines: + - _raw: $wt$ +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + s := c.FindSampleByName("weightuser") + assert.NotNil(t, s) + for _, tk := range s.Tokens { + if tk.Name == "wt" { + assert.NotEmpty(t, tk.WeightedChoice, "should have weighted choices resolved") + assert.Equal(t, 2, len(tk.WeightedChoice)) + } + } +} + +func TestValidateTokenSampleResolution(t *testing.T) { + ResetConfig() + configStr := ` +samples: + - name: choices + disabled: true + lines: + - _raw: alpha + - _raw: beta + - _raw: gamma + - name: resolver + interval: 1 + count: 1 + endIntervals: 1 + tokens: + - name: pick + format: template + token: $pick$ + type: choice + sample: choices + lines: + - _raw: $pick$ +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + s := c.FindSampleByName("resolver") + assert.NotNil(t, s) + for _, tk := range s.Tokens { + if tk.Name == "pick" { + assert.Equal(t, 3, len(tk.Choice), "should resolve 3 choices from sample") + assert.Contains(t, tk.Choice, "alpha") + assert.Contains(t, tk.Choice, "beta") + assert.Contains(t, tk.Choice, "gamma") + } + } +} + +func TestValidateExportMode(t *testing.T) { + ResetConfig() + + dir := t.TempDir() + configFile := filepath.Join(dir, "export.yml") + os.WriteFile(configFile, []byte(` +samples: + - name: exportsample + interval: 1 + count: 1 + lines: + - _raw: test +`), 0644) + + cc := 
ConfigConfig{ + FullConfig: configFile, + Export: true, + } + c := BuildConfig(cc) + + // In export mode, defaults should NOT be set + assert.Equal(t, 0, c.Global.GeneratorWorkers, "export mode should not set defaults") + assert.Equal(t, "", c.Global.Output.Outputter, "export mode should not set output defaults") +} + +func TestMergeMixConfig(t *testing.T) { + c := &Config{} + nc := &Config{ + Samples: []*Sample{ + {Name: "mixsample", Count: 5, Interval: 2}, + }, + } + m := &Mix{ + Count: 10, + Interval: 3, + Begin: "-60s", + End: "now", + } + c.mergeMixConfig(nc, m) + + assert.Equal(t, 1, len(c.Samples)) + assert.Equal(t, "mixsample", c.Samples[0].Name) + assert.Equal(t, 10, c.Samples[0].Count) + assert.Equal(t, 3, c.Samples[0].Interval) +} + +func TestParseFileConfigYAMLError(t *testing.T) { + ResetConfig() + + dir := t.TempDir() + badFile := filepath.Join(dir, "bad.yml") + // Invalid YAML: tabs mixed with spaces in wrong places + os.WriteFile(badFile, []byte("{\n bad yaml content: [unclosed\n"), 0644) + + c := &Config{cc: ConfigConfig{}} + s := &Sample{} + err := c.parseFileConfig(s, badFile) + assert.Error(t, err) + assert.Contains(t, err.Error(), "parsing error in file") +} + +func TestParseFileConfigJSONError(t *testing.T) { + ResetConfig() + + dir := t.TempDir() + badFile := filepath.Join(dir, "bad.json") + os.WriteFile(badFile, []byte("{invalid json"), 0644) + + c := &Config{cc: ConfigConfig{}} + s := &Sample{} + err := c.parseFileConfig(s, badFile) + assert.Error(t, err) + assert.Contains(t, err.Error(), "parsing error in file") +} + +func TestParseFileConfigNotExists(t *testing.T) { + ResetConfig() + + c := &Config{cc: ConfigConfig{}} + s := &Sample{} + err := c.parseFileConfig(s, "/nonexistent/path/file.yml") + assert.Error(t, err) +} + +func TestParseWebConfigSuccess(t *testing.T) { + ResetConfig() + + yamlContent := ` +name: websample +interval: 1 +count: 1 +lines: + - _raw: web test +` + ts := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(yamlContent)) + })) + defer ts.Close() + + c := &Config{cc: ConfigConfig{}} + s := &Sample{} + err := c.parseWebConfig(s, ts.URL) + assert.NoError(t, err) + assert.Equal(t, "websample", s.Name) +} + +func TestParseWebConfigJSONFallback(t *testing.T) { + ResetConfig() + + jsonContent := `{"name": "jsonsample", "interval": 1}` + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(jsonContent)) + })) + defer ts.Close() + + c := &Config{cc: ConfigConfig{}} + s := &Sample{} + err := c.parseWebConfig(s, ts.URL) + assert.NoError(t, err) + assert.Equal(t, "jsonsample", s.Name) +} + +func TestParseWebConfigBadContent(t *testing.T) { + ResetConfig() + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("<<>>")) + })) + defer ts.Close() + + c := &Config{cc: ConfigConfig{}} + s := &Sample{} + err := c.parseWebConfig(s, ts.URL) + // JSON fallback parse returns an error for non-JSON content + assert.Error(t, err) + assert.Equal(t, "", s.Name, "garbage content should not set sample name") +} + +func TestMergeMixConfigDuplicate(t *testing.T) { + c := &Config{ + Samples: []*Sample{ + {Name: "existing"}, + }, + } + nc := &Config{ + Samples: []*Sample{ + {Name: "existing", Count: 5}, + }, + } + m := &Mix{} + c.mergeMixConfig(nc, m) + + // Should not add duplicate + assert.Equal(t, 1, len(c.Samples)) +} + +func TestGetAPIURLDefault(t *testing.T) { + os.Unsetenv("GOGEN_APIURL") + url := getAPIURL() + assert.Equal(t, "https://api.gogen.io", url) +} + +func TestGetAPIURLCustom(t *testing.T) { + os.Setenv("GOGEN_APIURL", "http://localhost:4000") + defer os.Unsetenv("GOGEN_APIURL") + url := getAPIURL() + assert.Equal(t, "http://localhost:4000", url) +} + +func TestValidateFromSample(t *testing.T) { + ResetConfig() + + configStr := ` +samples: + - name: 
sourcesample + disabled: true + lines: + - _raw: source line 1 + - _raw: source line 2 + - name: copiedsample + fromSample: sourcesample + interval: 1 + count: 1 + endIntervals: 1 +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + s := c.FindSampleByName("copiedsample") + assert.NotNil(t, s) + assert.Len(t, s.Lines, 2, "copiedsample should have lines from sourcesample") +} + +func TestNewGeneratorStateNumericInit(t *testing.T) { + s := &Sample{ + CustomGenerator: &GeneratorConfig{ + Init: map[string]string{ + "count": "42", + "rate": "3.14", + "label": "hello", + }, + }, + Lines: []map[string]string{ + {"_raw": "line1", "host": "h1"}, + {"_raw": "line2", "host": "h2"}, + }, + } + + gs := NewGeneratorState(s) + assert.NotNil(t, gs.LuaState) + assert.NotNil(t, gs.LuaLines) + + // Numeric values should be stored as LNumber + countVal := gs.LuaState.RawGetString("count") + assert.NotNil(t, countVal) + + // String values should be stored as LString + labelVal := gs.LuaState.RawGetString("label") + assert.NotNil(t, labelVal) + + // Lines table should have entries + assert.Equal(t, 2, gs.LuaLines.Len()) +} + +func TestNewGeneratorStateEmptyInit(t *testing.T) { + s := &Sample{ + CustomGenerator: &GeneratorConfig{ + Init: map[string]string{}, + }, + Lines: []map[string]string{}, + } + + gs := NewGeneratorState(s) + assert.NotNil(t, gs.LuaState) + assert.NotNil(t, gs.LuaLines) + assert.Equal(t, 0, gs.LuaLines.Len()) +} + +func TestBuildConfigExportMode(t *testing.T) { + ResetConfig() + os.Setenv("GOGEN_HOME", "..") + os.Setenv("GOGEN_ALWAYS_REFRESH", "1") + home := filepath.Join("..", "tests", "tokens") + os.Setenv("GOGEN_SAMPLES_DIR", home) + + cc := ConfigConfig{ + SamplesDir: home, + Home: "..", + Export: true, + } + c := BuildConfig(cc) + assert.NotNil(t, c) + // In export mode, samples should have lines populated inline + for _, s := range c.Samples { + if s.Name == "tokens" { + assert.Greater(t, len(s.Lines), 0) + } + } +} + 
+func TestBuildConfigWithGlobalFile(t *testing.T) { + ResetConfig() + globalFile := filepath.Join("..", "tests", "rater", "defaultrater.yml") + cc := ConfigConfig{ + FullConfig: globalFile, + Home: "..", + } + c := BuildConfig(cc) + assert.NotNil(t, c) +} + +func TestValidateInvalidEarliestTime(t *testing.T) { + ResetConfig() + configStr := ` +global: + rotInterval: 1 + output: + outputter: devnull + outputTemplate: raw +samples: + - name: badtime + description: "Bad earliest time" + earliest: "not_a_valid_time_string!!!" + latest: now + interval: 1 + count: 1 + endIntervals: 1 + lines: + - _raw: test event +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + s := c.FindSampleByName("badtime") + // With invalid earliest, EarliestParsed should default to 0 + assert.Equal(t, time.Duration(0), s.EarliestParsed) +} + +func TestValidateInvalidLatestTime(t *testing.T) { + ResetConfig() + configStr := ` +global: + rotInterval: 1 + output: + outputter: devnull + outputTemplate: raw +samples: + - name: badlatest + description: "Bad latest time" + earliest: now + latest: "not_a_valid_time_string!!!" + interval: 1 + count: 1 + endIntervals: 1 + lines: + - _raw: test event +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + s := c.FindSampleByName("badlatest") + // With invalid latest, LatestParsed should default to 0 + assert.Equal(t, time.Duration(0), s.LatestParsed) +} + +func TestNewConfigNoGogenHome(t *testing.T) { + ResetConfig() + os.Unsetenv("GOGEN_HOME") + os.Setenv("GOGEN_ALWAYS_REFRESH", "1") + os.Unsetenv("GOGEN_FULLCONFIG") + os.Unsetenv("GOGEN_CONFIG_DIR") + os.Unsetenv("GOGEN_SAMPLES_DIR") + + c := NewConfig() + assert.NotNil(t, c) + // When GOGEN_HOME is not set, it should default to "." 
+ assert.Equal(t, ".", os.Getenv("GOGEN_HOME")) +} + +func TestValidateNoLinesDisablesSample(t *testing.T) { + ResetConfig() + configStr := ` +global: + rotInterval: 1 + output: + outputter: devnull + outputTemplate: raw +samples: + - name: nolines + description: "Sample with no lines" + interval: 1 + count: 1 + endIntervals: 1 +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + // After Clean(), disabled samples are removed + // So FindSampleByName should return an empty sample (not in the list) + found := false + for _, s := range c.Samples { + if s.Name == "nolines" { + found = true + } + } + assert.False(t, found, "disabled sample with no lines should be removed by Clean()") +} + +func TestValidateRatedTokenDefaultRater(t *testing.T) { + ResetConfig() + configStr := ` +global: + rotInterval: 1 + output: + outputter: devnull + outputTemplate: raw +samples: + - name: rated_test + description: "Rated token default rater" + interval: 1 + count: 1 + endIntervals: 1 + tokens: + - name: myrated + format: template + type: rated + replacement: int + lower: 0 + upper: 100 + lines: + - _raw: value=$myrated$ +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + s := c.FindSampleByName("rated_test") + // Rated token with no raterString should default to "default" + for _, tok := range s.Tokens { + if tok.Name == "myrated" { + assert.Equal(t, "default", tok.RaterString) + } + } +} + +func TestValidateLuaGenerator(t *testing.T) { + ResetConfig() + configStr := ` +global: + rotInterval: 1 + output: + outputter: devnull + outputTemplate: raw +samples: + - name: luagen + description: "Lua generator sample" + generator: mygen + interval: 1 + count: 1 + endIntervals: 1 + lines: + - _raw: test event +generators: + - name: mygen + script: | + lines = getLines() + return send(lines) +` + SetupFromString(configStr) + defer CleanupConfigAndEnvironment() + + c := NewConfig() + s := 
c.FindSampleByName("luagen") + assert.NotNil(t, s) + assert.Equal(t, "mygen", s.Generator) + assert.NotNil(t, s.CustomGenerator) +} diff --git a/internal/config_validate.go b/internal/config_validate.go new file mode 100644 index 0000000..23f26fc --- /dev/null +++ b/internal/config_validate.go @@ -0,0 +1,412 @@ +package internal + +import ( + "sort" + "strconv" + "sync" + "time" + + log "github.com/coccyx/gogen/logger" + "github.com/coccyx/timeparser" + lua "github.com/yuin/gopher-lua" +) + +// validate takes a sample and checks against any rules which may cause the configuration to be invalid. +// This hopefully centralizes logic for valid configs, disabling any samples which are not valid and +// preventing this logic from sprawling all over the code base. +// Also finds any references from tokens to other samples and +// updates the token to point to the sample data +// Also fixes up any additional things which are needed, like weighted choice string +// string map to the randutil Choice struct +func (c *Config) validate(s *Sample) { + if !s.realSample { + return + } + + s.Buf = &c.Buf + if s.Generator == "" { + s.Generator = defaultGenerator + } + if len(s.Name) == 0 { + s.Disabled = true + s.realSample = false + } else if len(s.Lines) == 0 && (s.Generator == "sample" || s.Generator == "replay") { + s.Disabled = true + s.realSample = false + log.Errorf("Disabling sample '%s', no lines in sample", s.Name) + } else { + s.realSample = true + } + + // Put the output into the sample for convenience + s.Output = &c.Global.Output + + // Setup defaults + setDefault(&s.Earliest, defaultEarliest) + setDefault(&s.Latest, defaultLatest) + setDefault(&s.Field, DefaultField) + setDefault(&s.RaterString, defaultRater) + + ParseBeginEnd(s) + + // Parse earliest and latest as relative times + n := time.Now() + now := func() time.Time { + return n + } + if p, err := timeparser.TimeParserNow(s.Earliest, now); err != nil { + log.Errorf("Error parsing earliest time '%s' for sample 
'%s', using Now", s.Earliest, s.Name) + s.EarliestParsed = time.Duration(0) + } else { + s.EarliestParsed = n.Sub(p) * -1 + } + if p, err := timeparser.TimeParserNow(s.Latest, now); err != nil { + log.Errorf("Error parsing latest time '%s' for sample '%s', using Now", s.Latest, s.Name) + s.LatestParsed = time.Duration(0) + } else { + s.LatestParsed = n.Sub(p) * -1 + } + + c.resolveTokenSamples(s) + + if s.EarliestParsed > s.LatestParsed { + log.Errorf("Earliest time cannot be greater than latest for sample '%s', disabling Sample", s.Name) + s.Disabled = true + return + } + if s.Interval == 0 && s.Generator != "replay" { + log.Infof("No interval set for sample '%s', setting endIntervals to 1", s.Name) + s.EndIntervals = 1 + } + + c.validateTokens(s) + c.computeSinglePass(s) + c.setupGenerator(s) +} + +// resolveTokenSamples resolves references from tokens to other samples, +// setting up Choice, WeightedChoice, or FieldChoice data on each token. +func (c *Config) resolveTokenSamples(s *Sample) { + for i := 0; i < len(s.Tokens); i++ { + if s.Tokens[i].Type == "rated" && s.Tokens[i].RaterString == "" { + s.Tokens[i].RaterString = "default" + } + if s.Tokens[i].Field == "" { + s.Tokens[i].Field = s.Field + } + // If format is template, then create a default token of $tokenname$ + if s.Tokens[i].Format == "template" && s.Tokens[i].Token == "" { + s.Tokens[i].Token = "$" + s.Tokens[i].Name + "$" + } + s.Tokens[i].Parent = s + s.Tokens[i].luaState = new(lua.LTable) + for j := 0; j < len(c.Samples); j++ { + if s.Tokens[i].SampleString != "" && s.Tokens[i].SampleString == c.Samples[j].Name { + log.Debugf("Resolving sample '%s' for token '%s'", c.Samples[j].Name, s.Tokens[i].Name) + s.Tokens[i].Sample = c.Samples[j] + // See if a field exists other than _raw, if so, FieldChoice + otherfield := false + if len(c.Samples[j].Lines) > 0 { + for k := range c.Samples[j].Lines[0] { + if k != "_raw" { + otherfield = true + break + } + } + } + if otherfield { + // If we're a 
structured sample and we contain the field "_weight", then we create a weighted choice struct + // Otherwise we're a fieldChoice + _, ok := c.Samples[j].Lines[0]["_weight"] + _, ok2 := c.Samples[j].Lines[0][s.Tokens[i].SrcField] + if ok && ok2 { + for _, line := range c.Samples[j].Lines { + weight, err := strconv.Atoi(line["_weight"]) + if err != nil { + weight = 0 + } + s.Tokens[i].WeightedChoice = append(s.Tokens[i].WeightedChoice, WeightedChoice{Weight: weight, Choice: line[s.Tokens[i].SrcField]}) + } + } else { + s.Tokens[i].FieldChoice = c.Samples[j].Lines + } + } else { + temp := make([]string, 0, len(c.Samples[j].Lines)) + for _, line := range c.Samples[j].Lines { + if _, ok := line["_raw"]; ok { + if len(line["_raw"]) > 0 { + temp = append(temp, line["_raw"]) + } + } + } + s.Tokens[i].Choice = temp + } + break + } + } + } +} + +// validateTokens checks token configurations for validity, disabling the sample if any token is invalid. +func (c *Config) validateTokens(s *Sample) { + for i, t := range s.Tokens { + switch t.Type { + case "random", "rated": + if t.Replacement == "int" || t.Replacement == "float" { + if t.Lower > t.Upper { + log.Errorf("Lower cannot be greater than Upper for token '%s' in sample '%s', disabling Sample", t.Name, s.Name) + s.Disabled = true + } else if t.Upper == 0 { + log.Errorf("Upper cannot be zero for token '%s' in sample '%s', disabling Sample", t.Name, s.Name) + s.Disabled = true + } + } else if t.Replacement == "string" || t.Replacement == "hex" { + if t.Length == 0 { + log.Errorf("Length cannot be zero for token '%s' in sample '%s', disabling Sample", t.Name, s.Name) + s.Disabled = true + } + } else { + if t.Replacement != "guid" && t.Replacement != "ipv4" && t.Replacement != "ipv6" { + log.Errorf("Replacement '%s' is invalid for token '%s' in sample '%s'", t.Replacement, t.Name, s.Name) + s.Disabled = true + } + } + case "choice": + if len(t.Choice) == 0 || t.Choice == nil { + log.Errorf("Zero choice items for token '%s' in 
sample '%s', disabling Sample", t.Name, s.Name) + s.Disabled = true + } + case "weightedChoice": + if len(t.WeightedChoice) == 0 || t.WeightedChoice == nil { + log.Errorf("Zero choice items for token '%s' in sample '%s', disabling Sample", t.Name, s.Name) + s.Disabled = true + } + case "fieldChoice": + if len(t.FieldChoice) == 0 || t.FieldChoice == nil { + log.Errorf("Zero choice items for token '%s' in sample '%s', disabling Sample", t.Name, s.Name) + s.Disabled = true + } + for _, choice := range t.FieldChoice { + if _, ok := choice[t.SrcField]; !ok { + log.Errorf("Source field '%s' does not exist for token '%s' in row '%#v' in sample '%s', disabling Sample", t.SrcField, t.Name, choice, s.Name) + s.Disabled = true + break + } + } + case "script": + s.Tokens[i].mutex = &sync.Mutex{} + for k, v := range t.Init { + vAsNum, err := strconv.ParseFloat(v, 64) + if err == nil { + t.luaState.RawSet(lua.LString(k), lua.LNumber(vAsNum)) + } else { + t.luaState.RawSet(lua.LString(k), lua.LString(v)) + } + } + } + } +} + +// computeSinglePass checks if SinglePass optimization is feasible for the sample +// by verifying all tokens can be located in each line without overlapping. 
+func (c *Config) computeSinglePass(s *Sample) { + if s.Disabled { + return + } + s.SinglePass = true + + var tlines []map[string]tokenspos + +outer: + for _, l := range s.Lines { + tp := make(map[string]tokenspos) + for j, t := range s.Tokens { + var pos tokenpos + var err error + offsets, err := t.GetReplacementOffsets(l[t.Field]) + if err != nil || len(offsets) == 0 { + log.Infof("Error getting replacements for token '%s' in event '%s', disabling SinglePass", t.Name, l[t.Field]) + s.SinglePass = false + break outer + } + for _, offset := range offsets { + pos1 := offset[0] + pos2 := offset[1] + if pos1 < 0 || pos2 < 0 { + log.Infof("Token '%s' not found in event '%s', disabling SinglePass", t.Name, l) + s.SinglePass = false + break outer + } + pos.Pos1 = pos1 + pos.Pos2 = pos2 + pos.Token = j + tp[t.Field] = append(tp[t.Field], pos) + } + } + + // Ensure we don't have any tokens overlapping one another for singlepass + for _, v := range tp { + sort.Sort(v) + + lastpos := 0 + lasttoken := "" + maxpos := 0 + for _, pos := range v { + if lastpos > pos.Pos1 { + log.Infof("Token '%s' extends beyond beginning of token '%s', disabling SinglePass", lasttoken, s.Tokens[pos.Token].Name) + s.SinglePass = false + break outer + } + if maxpos > pos.Pos1 { + log.Infof("Some former token extends beyond the beginning of token '%s', disabling SinglePass", s.Tokens[pos.Token].Name) + s.SinglePass = false + break outer + } + if pos.Pos2 > maxpos { + maxpos = pos.Pos2 + } + lastpos = pos.Pos2 + lasttoken = s.Tokens[pos.Token].Name + } + } + tlines = append(tlines, tp) + } + + if s.SinglePass { + // Break up each line and field according to the positions of the tokens + for i, line := range s.Lines { + if len(tlines) >= i && len(tlines) > 0 { + bline := make(map[string][]StringOrToken) + for field := range line { + var bfield []StringOrToken + if _, ok := tlines[i][field]; !ok { + bf := StringOrToken{T: nil, S: line[field]} + bfield = append(bfield, bf) + } else { + lastpos := 0 + 
for _, tp := range tlines[i][field] { + if tp.Pos1 == 0 { + bf := StringOrToken{T: &s.Tokens[tp.Token], S: ""} + bfield = append(bfield, bf) + lastpos = tp.Pos2 + } else { + bf := StringOrToken{T: nil, S: s.Lines[i][field][lastpos:tp.Pos1]} + bfield = append(bfield, bf) + bf = StringOrToken{T: &s.Tokens[tp.Token], S: ""} + bfield = append(bfield, bf) + lastpos = tp.Pos2 + } + } + if lastpos < len(s.Lines[i][field]) { + bf := StringOrToken{T: nil, S: s.Lines[i][field][lastpos:]} + bfield = append(bfield, bf) + } + } + bline[field] = bfield + } + s.BrokenLines = append(s.BrokenLines, bline) + } + } + } +} + +// setupGenerator configures the sample's generator: replay offsets for replay generators, +// or custom Lua generator linkage for non-sample generators. +func (c *Config) setupGenerator(s *Sample) { + if s.Generator == "replay" { + s.ReplayOffsets = make([]time.Duration, len(s.Lines)) + var lastts time.Time + var avgOffset time.Duration + outer2: + for i := 0; i < len(s.Lines); i++ { + inner2: + for _, t := range s.Tokens { + if t.Type == "timestamp" || t.Type == "gotimestamp" || t.Type == "epochtimestamp" { + offsets, err := t.GetReplacementOffsets(s.Lines[i][t.Field]) + if err != nil || len(offsets) == 0 { + log.WithFields(log.Fields{ + "token": t.Name, + "sample": s.Name, + "err": err, + }).Errorf("Error getting timestamp offsets, disabling sample") + s.Disabled = true + break outer2 + } + pos1 := offsets[0][0] + pos2 := offsets[0][1] + ts, err := t.ParseTimestamp(s.Lines[i][t.Field][pos1:pos2]) + if err != nil { + log.WithFields(log.Fields{ + "token": t.Name, + "sample": s.Name, + "err": err, + "event": s.Lines[0][t.Field], + }).Errorf("Error parsing timestamp, disabling sample") + s.Disabled = true + break outer2 + } + if i > 0 { + s.ReplayOffsets[i-1] = lastts.Sub(ts) * -1 + avgOffset = (avgOffset + s.ReplayOffsets[i-1]) / 2 + } + lastts = ts + break inner2 + } + } + s.ReplayOffsets[len(s.ReplayOffsets)-1] = avgOffset + } + log.WithFields(log.Fields{ + 
"sample": s.Name, + "ReplayOffsets": s.ReplayOffsets, + }).Debugf("ReplayOffsets values") + } else if s.Generator != "sample" { + for _, g := range c.Generators { + if g.Name == s.Generator { + s.LuaMutex = &sync.Mutex{} + s.CustomGenerator = g + if g.SingleThreaded { + s.GeneratorState = NewGeneratorState(s) + } + } + } + if s.CustomGenerator == nil { + log.Errorf("Generator '%s' not found for sample '%s', disabling sample", s.Generator, s.Name) + s.Disabled = true + } + } +} + +// validateRater returns a copy of the rater with the Options properly cast +func (c *Config) validateRater(r *RaterConfig) { + configRaterKeys := map[string]bool{ + "HourOfDay": true, + "MinuteOfHour": true, + "DayOfWeek": true, + } + + opt := make(map[string]interface{}) + for k, v := range r.Options { + var newvset interface{} + if configRaterKeys[k] { + newv := make(map[int]float64) + vcast := v.(map[interface{}]interface{}) + for k2, v2 := range vcast { + k2int := k2.(int) + v2float, ok := v2.(float64) + if !ok { + v2int, ok := v2.(int) + if !ok { + log.Fatalf("Rater value '%#v' of key '%#v' for rater '%s' in '%s' is not a float or int", v2, k2, r.Name, k) + } + v2float = float64(v2int) + } + newv[k2int] = v2float + } + newvset = newv + } else { + newvset = v + } + opt[k] = newvset + } + r.Options = opt +} diff --git a/internal/defaults.go b/internal/defaults.go index aadba83..c5135e8 100644 --- a/internal/defaults.go +++ b/internal/defaults.go @@ -19,7 +19,7 @@ const defaultOutputTemplate = "raw" const defaultGenerator = "sample" const defaultEarliest = "now" const defaultLatest = "now" -const defaultRandomizeEvents = false + const defaultRater = "default" // Default file output values @@ -47,6 +47,9 @@ const defaultOutQueueLength = 10 // This is the user facilitiy (1 << 3 == 8) at INFO (6) level, (8+6) const defaultSyslogPriority = 14 +// configExtensions defines the file extensions accepted for YAML/JSON config files +var configExtensions = map[string]bool{".yml": true, ".yaml": 
true, ".json": true} + var ( defaultCSVTemplate *Template defaultJSONTemplate *Template @@ -57,6 +60,15 @@ var ( defaultConfigRaterConfig *RaterConfig ) +// uniformRateMap returns a map[int]float64 with keys 0..n-1 all set to 1.0. +func uniformRateMap(n int) map[int]float64 { + m := make(map[int]float64, n) + for i := 0; i < n; i++ { + m[i] = 1.0 + } + return m +} + func init() { defaultCSVTemplate = &Template{ Name: "csv", @@ -91,103 +103,9 @@ func init() { Name: "config", Type: "config", Options: map[string]interface{}{ - "HourOfDay": map[int]float64{ - 0: 1.0, - 1: 1.0, - 2: 1.0, - 3: 1.0, - 4: 1.0, - 5: 1.0, - 6: 1.0, - 7: 1.0, - 8: 1.0, - 9: 1.0, - 10: 1.0, - 11: 1.0, - 12: 1.0, - 13: 1.0, - 14: 1.0, - 15: 1.0, - 16: 1.0, - 17: 1.0, - 18: 1.0, - 19: 1.0, - 20: 1.0, - 21: 1.0, - 22: 1.0, - 23: 1.0, - }, - "DayOfWeek": map[int]float64{ - 0: 1.0, - 1: 1.0, - 2: 1.0, - 3: 1.0, - 4: 1.0, - 5: 1.0, - 6: 1.0, - }, - "MinuteOfHour": map[int]float64{ - 0: 1.0, - 1: 1.0, - 2: 1.0, - 3: 1.0, - 4: 1.0, - 5: 1.0, - 6: 1.0, - 7: 1.0, - 8: 1.0, - 9: 1.0, - 10: 1.0, - 11: 1.0, - 12: 1.0, - 13: 1.0, - 14: 1.0, - 15: 1.0, - 16: 1.0, - 17: 1.0, - 18: 1.0, - 19: 1.0, - 20: 1.0, - 21: 1.0, - 22: 1.0, - 23: 1.0, - 24: 1.0, - 25: 1.0, - 26: 1.0, - 27: 1.0, - 28: 1.0, - 29: 1.0, - 30: 1.0, - 31: 1.0, - 32: 1.0, - 33: 1.0, - 34: 1.0, - 35: 1.0, - 36: 1.0, - 37: 1.0, - 38: 1.0, - 39: 1.0, - 40: 1.0, - 41: 1.0, - 42: 1.0, - 43: 1.0, - 44: 1.0, - 45: 1.0, - 46: 1.0, - 47: 1.0, - 48: 1.0, - 49: 1.0, - 50: 1.0, - 51: 1.0, - 52: 1.0, - 53: 1.0, - 54: 1.0, - 55: 1.0, - 56: 1.0, - 57: 1.0, - 58: 1.0, - 59: 1.0, - }, + "HourOfDay": uniformRateMap(24), + "DayOfWeek": uniformRateMap(7), + "MinuteOfHour": uniformRateMap(60), }, } } diff --git a/internal/errors.go b/internal/errors.go new file mode 100644 index 0000000..7f9496c --- /dev/null +++ b/internal/errors.go @@ -0,0 +1,19 @@ +package internal + +import "fmt" + +// HTTPError represents an HTTP response with a non-2xx status code. 
+type HTTPError struct { + StatusCode int + URL string + Body string +} + +func (e *HTTPError) Error() string { + return fmt.Sprintf("HTTP %d from %s: %s", e.StatusCode, e.URL, e.Body) +} + +// IsNotFound returns true if the HTTP status code is 404. +func (e *HTTPError) IsNotFound() bool { + return e.StatusCode == 404 +} diff --git a/internal/github.go b/internal/github.go index 6d210fb..f36b51c 100644 --- a/internal/github.go +++ b/internal/github.go @@ -3,7 +3,7 @@ package internal // Mostly from https://jacobmartins.com/2016/02/29/getting-started-with-oauth2-in-go/ import ( - "io/ioutil" + "context" "net/http" "os" "path/filepath" @@ -55,7 +55,7 @@ func NewGitHub(requireauth bool) *GitHub { tokenFile := filepath.Join(os.ExpandEnv("$GOGEN_HOME"), ".githubtoken") _, err := os.Stat(tokenFile) if err == nil { - buf, err := ioutil.ReadFile(tokenFile) + buf, err := os.ReadFile(tokenFile) if err != nil { log.Fatalf("Error reading from file %s: %s", tokenFile, err) } @@ -72,7 +72,7 @@ func NewGitHub(requireauth bool) *GitHub { <-gh.done log.Debugf("Getting GitHub token '%s' from oauth", gh.token) - err = ioutil.WriteFile(tokenFile, []byte(gh.token), 400) + err = os.WriteFile(tokenFile, []byte(gh.token), 400) if err != nil { log.Fatalf("Error writing token to file %s: %s", tokenFile, err) } @@ -81,7 +81,7 @@ func NewGitHub(requireauth bool) *GitHub { ts := oauth2.StaticTokenSource( &oauth2.Token{AccessToken: gh.token}, ) - tc := oauth2.NewClient(oauth2.NoContext, ts) + tc := oauth2.NewClient(context.Background(), ts) gh.client = github.NewClient(tc) } else { gh.client = github.NewClient(nil) @@ -103,7 +103,7 @@ func (gh *GitHub) handleGitHubCallback(w http.ResponseWriter, r *http.Request) { } code := r.FormValue("code") - token, err := oauthConf.Exchange(oauth2.NoContext, code) + token, err := oauthConf.Exchange(context.Background(), code) if err != nil { log.Errorf("Code exchange failed with '%s'\n", err) http.Redirect(w, r, "/", http.StatusTemporaryRedirect) diff --git 
a/internal/gogen.go b/internal/gogen.go index 94029fb..5da8f19 100644 --- a/internal/gogen.go +++ b/internal/gogen.go @@ -3,11 +3,13 @@ package internal import ( "bytes" "encoding/json" + "errors" "fmt" - "io/ioutil" + "io" "net/http" "net/url" "os" + "time" log "github.com/coccyx/gogen/logger" "github.com/kr/pretty" @@ -32,36 +34,77 @@ type GogenList struct { Description string } +// defaultAPIClient is the shared HTTP client for API calls with a reasonable timeout. +var defaultAPIClient = &http.Client{Timeout: 30 * time.Second} + +// doHTTPRequest executes an HTTP request, reads the response body, and returns +// the body bytes. It properly closes resp.Body and returns an *HTTPError for +// non-2xx status codes. +func doHTTPRequest(client *http.Client, req *http.Request) ([]byte, error) { + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request to %s failed: %w", req.URL, err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body from %s: %w", req.URL, err) + } + + if resp.StatusCode < 200 || resp.StatusCode > 299 { + return nil, &HTTPError{ + StatusCode: resp.StatusCode, + URL: req.URL.String(), + Body: string(body), + } + } + return body, nil +} + +// doGet performs an HTTP GET request to the given URL using the default API client. +func doGet(url string) ([]byte, error) { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, fmt.Errorf("error creating request for %s: %w", url, err) + } + return doHTTPRequest(defaultAPIClient, req) +} + +// doPost performs an HTTP POST request to the given URL using the default API client. 
+func doPost(url string, body io.Reader, headers map[string]string) ([]byte, error) { + req, err := http.NewRequest("POST", url, body) + if err != nil { + return nil, fmt.Errorf("error creating request for %s: %w", url, err) + } + for k, v := range headers { + req.Header.Add(k, v) + } + return doHTTPRequest(defaultAPIClient, req) +} + // List calls /v1/list -func List() []GogenList { +func List() ([]GogenList, error) { return listsearch(fmt.Sprintf("%s/v1/list", getAPIURL())) } // Search calls /v1/search -func Search(q string) []GogenList { +func Search(q string) ([]GogenList, error) { return listsearch(fmt.Sprintf("%s/v1/search?q=%s", getAPIURL(), url.QueryEscape(q))) } -func listsearch(url string) (ret []GogenList) { - client := &http.Client{} - resp, err := client.Get(url) - if err != nil || resp.StatusCode != 200 { - if resp.StatusCode != 200 { - body, _ := ioutil.ReadAll(resp.Body) - log.Fatalf("Non 200 response code searching for Gogen: %s", string(body)) - } else { - log.Fatalf("Error retrieving list of Gogens: %s", err) - } - } - body, err := ioutil.ReadAll(resp.Body) +func listsearch(url string) ([]GogenList, error) { + body, err := doGet(url) if err != nil { - log.Fatalf("Error reading body from response: %s", err) + return nil, fmt.Errorf("error retrieving list of Gogens: %w", err) } var list map[string]interface{} err = json.Unmarshal(body, &list) - // log.Debugf("List body: %s", string(body)) - // log.Debugf("list: %s", fmt.Sprintf("%# v", pretty.Formatter(list))) + if err != nil { + return nil, fmt.Errorf("error unmarshaling list response: %w", err) + } items := list["Items"].([]interface{}) + var ret []GogenList for _, item := range items { tempitem := item.(map[string]interface{}) if _, ok := tempitem["gogen"]; !ok { @@ -74,47 +117,33 @@ func listsearch(url string) (ret []GogenList) { ret = append(ret, li) } log.Debugf("List: %# v", pretty.Formatter(ret)) - return ret + return ret, nil } // Get calls /v1/get var Get = func(q string) (g GogenInfo, 
err error) { - client := &http.Client{} url := fmt.Sprintf("%s/v1/get/%s", getAPIURL(), q) log.Debugf("Calling %s", url) - resp, err := client.Get(url) - if err != nil || resp.StatusCode != 200 { - if resp != nil { - if resp.StatusCode == 404 { - return g, fmt.Errorf("Could not find Gogen: %s\n", q) - } - if resp.StatusCode != 200 { - body, _ := ioutil.ReadAll(resp.Body) - return g, fmt.Errorf("Non 200 response code retrieving Gogen: %s", string(body)) - } - } else { - return g, fmt.Errorf("Error retrieving Gogen %s: %s", q, err) - } - } - body, err := ioutil.ReadAll(resp.Body) + body, err := doGet(url) if err != nil { - return g, fmt.Errorf("Error reading body from response: %s", err) + var httpErr *HTTPError + if errors.As(err, &httpErr) && httpErr.IsNotFound() { + return g, fmt.Errorf("could not find Gogen %s: %w", q, err) + } + return g, fmt.Errorf("error retrieving Gogen %s: %w", q, err) } - // log.Debugf("Body: %s", body) var gogen map[string]interface{} err = json.Unmarshal(body, &gogen) if err != nil { - return g, fmt.Errorf("Error unmarshaling body: %s", err) + return g, fmt.Errorf("error unmarshaling body: %w", err) } - // log.Debugf("gogen: %# v", pretty.Formatter(gogen)) tmp, err := json.Marshal(gogen["Item"]) if err != nil { - return g, fmt.Errorf("Error converting Item to JSON: %s", err) + return g, fmt.Errorf("error converting Item to JSON: %w", err) } - // log.Debugf("tmp: %s", string(tmp)) err = json.Unmarshal(tmp, &g) if err != nil { - return g, fmt.Errorf("Error unmarshaling item: %s", err) + return g, fmt.Errorf("error unmarshaling item: %w", err) } gCopy := g gCopy.Config = "redacted" @@ -123,38 +152,26 @@ var Get = func(q string) (g GogenInfo, err error) { } // Upsert calls /v1/upsert -func Upsert(g GogenInfo) { +func Upsert(g GogenInfo) error { gh := NewGitHub(true) - upsert(g, gh) + return upsert(g, gh) } -func upsert(g GogenInfo, gh *GitHub) { - client := &http.Client{} - +func upsert(g GogenInfo, gh *GitHub) error { b, err := 
json.Marshal(g) if err != nil { - log.Fatalf("Error marshaling Gogen %#v: %s", g, err) + return fmt.Errorf("error marshaling Gogen %#v: %w", g, err) } - // log.Debugf("Body: %s", string(b)) - req, _ := http.NewRequest("POST", fmt.Sprintf("%s/v1/upsert", getAPIURL()), bytes.NewReader(b)) - // Still need GitHub token for authorization to verify user identity - req.Header.Add("Authorization", "token "+gh.token) - resp, err := client.Do(req) - if err != nil || resp.StatusCode != 200 { - if resp.StatusCode != 200 { - body, _ := ioutil.ReadAll(resp.Body) - log.Fatalf("Non 200 response code Upserting: %s", string(body)) - } else { - log.Fatalf("Error POSTing to upsert: %s", err) - } + headers := map[string]string{ + "Authorization": "token " + gh.token, + } + _, err = doPost(fmt.Sprintf("%s/v1/upsert", getAPIURL()), bytes.NewReader(b), headers) + if err != nil { + return fmt.Errorf("error upserting Gogen: %w", err) } - // body, err := ioutil.ReadAll(resp.Body) - // if err != nil { - // log.Fatalf("Error reading response body: %s", err) - // } - // log.Debugf("Response Body: %s", body) log.Debugf("Upserted: %# v", pretty.Formatter(g)) + return nil } // getAPIURL returns the API URL from environment variable or default value diff --git a/internal/gogen_test.go b/internal/gogen_test.go index dbf287e..6dd7df6 100644 --- a/internal/gogen_test.go +++ b/internal/gogen_test.go @@ -2,7 +2,7 @@ package internal import ( "encoding/json" - "io/ioutil" + "io" "net/http" "net/http/httptest" "os" @@ -119,7 +119,7 @@ func mockGogenServer(t *testing.T) (*httptest.Server, []GogenList) { } // Read the request body - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, "Error reading request body", http.StatusBadRequest) return @@ -163,7 +163,8 @@ func TestListWithMockServer(t *testing.T) { defer os.Setenv("GOGEN_APIURL", originalAPIURL) // Restore the original value when done // Call the List function - result := List() + result, err := List() 
+ assert.NoError(t, err, "List should not return an error") // Verify the results assert.NotEmpty(t, result, "List result should not be empty") @@ -233,7 +234,8 @@ func TestSearchWithMockServer(t *testing.T) { for _, tc := range testCases { t.Run("Search_"+tc.query, func(t *testing.T) { // Call the Search function - result := Search(tc.query) + result, err := Search(tc.query) + assert.NoError(t, err, "Search should not return an error for query: %s", tc.query) // Verify the results assert.NotEmpty(t, result, "Search result should not be empty for query: %s", tc.query) @@ -290,7 +292,7 @@ func TestGetWithMockServer(t *testing.T) { // Test with an invalid Gogen ID _, err = Get("coccyx/nonexistent") assert.Error(t, err, "Get should return an error for an invalid Gogen ID") - assert.Contains(t, err.Error(), "Could not find Gogen", "Error message should indicate Gogen not found") + assert.Contains(t, err.Error(), "could not find Gogen", "Error message should indicate Gogen not found") } // Mock the GitHub struct for testing Upsert @@ -329,11 +331,9 @@ func TestUpsertWithMockServer(t *testing.T) { Config: "", } - // Call the Upsert function - upsert(gogen, &GitHub{token: "mock-github-token"}) - - // Since Upsert doesn't return anything, we can only verify that it didn't panic - // The mock server will return an error if the request is not formatted correctly + // Call the upsert function + err := upsert(gogen, &GitHub{token: "mock-github-token"}) + assert.NoError(t, err, "upsert should not return an error") } // Helper function to filter expected items based on a query string diff --git a/internal/sample.go b/internal/sample.go index ff3947c..f4382fc 100644 --- a/internal/sample.go +++ b/internal/sample.go @@ -267,40 +267,39 @@ func (t Token) GenReplacement(choice int, et time.Time, lt time.Time, now time.T f := float64(randgen.Intn(upper-lower)+lower) / math.Pow10(t.Precision) return strconv.FormatFloat(f, 'f', t.Precision, 64), -1, nil case "string", "hex": - var ret 
string + var b strings.Builder + b.Grow(t.Length) + letters := randStringLetters + if t.Replacement == "hex" { + letters = randHexLetters + } for i := 0; i < t.Length; i++ { - if t.Replacement == "string" { - ri := randgen.Intn(len(randStringLetters)) - ret += randStringLetters[ri : ri+1] - } else { - ri := randgen.Intn(len(randHexLetters)) - ret += randHexLetters[ri : ri+1] - } + b.WriteByte(letters[randgen.Intn(len(letters))]) } - return ret, -1, nil + return b.String(), -1, nil case "guid": u := uuid.NewV4() return u.String(), -1, nil case "ipv4": - var ret string + var b strings.Builder + b.Grow(15) // max "255.255.255.255" for i := 0; i < 4; i++ { - ri := randgen.Intn(255) - ret += strconv.Itoa(ri) - if i < 3 { - ret += "." + if i > 0 { + b.WriteByte('.') } + b.WriteString(strconv.Itoa(randgen.Intn(255))) } - return ret, -1, nil + return b.String(), -1, nil case "ipv6": - var ret string + var b strings.Builder + b.Grow(39) // max "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff" for i := 0; i < 8; i++ { - ri := randgen.Intn(65535) - ret += fmt.Sprintf("%x", ri) - if i < 7 { - ret += ":" + if i > 0 { + b.WriteByte(':') } + fmt.Fprintf(&b, "%x", randgen.Intn(65535)) } - return ret, -1, nil + return b.String(), -1, nil } case "choice": if choice == -1 { diff --git a/internal/sample_test.go b/internal/sample_test.go index b05e008..2f66e89 100644 --- a/internal/sample_test.go +++ b/internal/sample_test.go @@ -230,6 +230,136 @@ func benchmarkToken(conf string, i int, b *testing.B) { } } +// mockRater implements Rater for testing rated tokens +type mockRater struct { + rate float64 +} + +func (m *mockRater) EventRate(s *Sample, now time.Time, count int) float64 { return m.rate } +func (m *mockRater) TokenRate(t Token, now time.Time) float64 { return m.rate } + +func TestGenReplacementRatedInt(t *testing.T) { + source := rand.NewSource(0) + randgen := rand.New(source) + fullevent := make(map[string]string) + now := time.Now() + + token := Token{ + Name: "ratedint", + Type: 
"rated", + Replacement: "int", + Lower: 10, + Upper: 20, + Rater: &mockRater{rate: 2.0}, + } + result, _, err := token.GenReplacement(-1, now, now, now, randgen, fullevent) + assert.NoError(t, err) + // With rate=2.0, value should be roughly doubled + val, _ := fmt.Sscanf(result, "%d", new(int)) + assert.Equal(t, 1, val, "should parse as an integer") +} + +func TestGenReplacementRatedIntEqualBounds(t *testing.T) { + source := rand.NewSource(0) + randgen := rand.New(source) + fullevent := make(map[string]string) + now := time.Now() + + token := Token{ + Name: "ratedintequal", + Type: "rated", + Replacement: "int", + Lower: 5, + Upper: 5, + Rater: &mockRater{rate: 1.0}, + } + result, _, err := token.GenReplacement(-1, now, now, now, randgen, fullevent) + assert.NoError(t, err) + assert.Equal(t, "5", result) +} + +func TestGenReplacementRatedFloat(t *testing.T) { + source := rand.NewSource(0) + randgen := rand.New(source) + fullevent := make(map[string]string) + now := time.Now() + + token := Token{ + Name: "ratedfloat", + Type: "rated", + Replacement: "float", + Lower: 1, + Upper: 10, + Precision: 2, + Rater: &mockRater{rate: 1.5}, + } + result, _, err := token.GenReplacement(-1, now, now, now, randgen, fullevent) + assert.NoError(t, err) + // Should be a float string with 2 decimal places + assert.Contains(t, result, ".") +} + +func TestGenReplacementRatedFloatEqualBounds(t *testing.T) { + source := rand.NewSource(0) + randgen := rand.New(source) + fullevent := make(map[string]string) + now := time.Now() + + token := Token{ + Name: "ratedfloatequal", + Type: "rated", + Replacement: "float", + Lower: 5, + Upper: 5, + Precision: 2, + Rater: &mockRater{rate: 1.0}, + } + result, _, err := token.GenReplacement(-1, now, now, now, randgen, fullevent) + assert.NoError(t, err) + assert.Equal(t, "5.00", result) +} + +func TestGenReplacementFieldChoice(t *testing.T) { + source := rand.NewSource(0) + randgen := rand.New(source) + fullevent := make(map[string]string) + now := 
time.Now() + + token := Token{ + Name: "fieldchoice", + Type: "fieldChoice", + SrcField: "city", + FieldChoice: []map[string]string{ + {"city": "NYC", "state": "NY"}, + {"city": "LA", "state": "CA"}, + }, + } + result, choice, err := token.GenReplacement(-1, now, now, now, randgen, fullevent) + assert.NoError(t, err) + assert.True(t, result == "NYC" || result == "LA") + assert.True(t, choice >= 0) + + // Test with specific choice + result, _, err = token.GenReplacement(0, now, now, now, randgen, fullevent) + assert.NoError(t, err) + assert.Equal(t, "NYC", result) +} + +func TestGenReplacementInvalidType(t *testing.T) { + source := rand.NewSource(0) + randgen := rand.New(source) + fullevent := make(map[string]string) + now := time.Now() + + token := Token{ + Name: "badtype", + Type: "nonexistenttype", + } + _, _, err := token.GenReplacement(-1, now, now, now, randgen, fullevent) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid type") +} + func BenchmarkReplacement(b *testing.B) { os.Setenv("GOGEN_HOME", "..") os.Setenv("GOGEN_ALWAYS_REFRESH", "1") diff --git a/internal/share.go b/internal/share.go index 08e72f1..101f55e 100644 --- a/internal/share.go +++ b/internal/share.go @@ -4,7 +4,6 @@ import ( "context" "encoding/csv" "fmt" - "io/ioutil" "net/url" "os" "path/filepath" @@ -37,8 +36,7 @@ func Push(name string, run Run) string { // Push all file based mixes for i := range ec.Mix { m := ec.Mix[i] - acceptableExtensions := map[string]bool{".yml": true, ".yaml": true, ".json": true} - if _, ok := acceptableExtensions[filepath.Ext(m.Sample)]; ok { + if _, ok := configExtensions[filepath.Ext(m.Sample)]; ok { sc := BuildConfig(ConfigConfig{ FullConfig: m.Sample, Export: true, @@ -101,7 +99,9 @@ func push(name string, genc *Config, pushc *Config, run Run) string { Version: version, Config: string(configYaml), } - Upsert(g) + if err := Upsert(g); err != nil { + log.Fatalf("Error upserting Gogen: %s", err) + } return *user.Login } @@ -129,7 +129,7 @@ func 
Pull(gogen string, dir string, deconstruct bool) { // Write the config to a file filename := filepath.Join(dir, name+".yml") - err = ioutil.WriteFile(filename, []byte(g.Config), 0644) + err = os.WriteFile(filename, []byte(g.Config), 0644) if err != nil { log.Fatalf("Error writing to file %s: %s", filename, err) } @@ -159,7 +159,7 @@ func PullFile(gogen string, filename string) { versionCacheFile := filepath.Join(os.ExpandEnv("$GOGEN_TMPDIR"), ".versioncache_"+url.QueryEscape(gogen)) _, err = os.Stat(versionCacheFile) if err == nil { - versionBytes, err := ioutil.ReadFile(versionCacheFile) + versionBytes, err := os.ReadFile(versionCacheFile) if err != nil { log.Fatalf("Error reading version cache file '%s': %s", versionCacheFile, err) } @@ -169,7 +169,7 @@ func PullFile(gogen string, filename string) { } if version == g.Version { log.Debugf("Reading config from cache file '%s'", cacheFile) - configContent, err = ioutil.ReadFile(cacheFile) + configContent, err = os.ReadFile(cacheFile) if err != nil { cached = false } else { @@ -291,7 +291,7 @@ func deconstructConfig(filename string, name string, dir string) { } outfname := filepath.Join(samplesDir, name+".yml") log.Debugf("Writing sample file for sammple '%s' at file: %s", s.Name, outfname) - err = ioutil.WriteFile(outfname, outb, 0644) + err = os.WriteFile(outfname, outb, 0644) if err != nil { log.Fatalf("Cannot write file %s: %s", outfname, err) } @@ -305,7 +305,7 @@ func deconstructConfig(filename string, name string, dir string) { if outb, err = yaml.Marshal(t); err != nil { log.Fatalf("Cannot Marshal template '%s', err: %s", t.Name, err) } - err = ioutil.WriteFile(filepath.Join(templatesDir, t.Name+".yml"), outb, 0644) + err = os.WriteFile(filepath.Join(templatesDir, t.Name+".yml"), outb, 0644) if err != nil { log.Fatalf("Error writing file %s", filepath.Join(templatesDir, t.Name+".yml")) } @@ -314,7 +314,7 @@ func deconstructConfig(filename string, name string, dir string) { for i, g := range c.Generators { if 
g.FileName != "" { fname := filepath.Base(g.FileName) - err = ioutil.WriteFile(filepath.Join(generatorsDir, fname), []byte(g.Script), 0644) + err = os.WriteFile(filepath.Join(generatorsDir, fname), []byte(g.Script), 0644) if err != nil { log.Fatalf("Error writing file %s", filepath.Join(generatorsDir, fname)) } @@ -327,7 +327,7 @@ func deconstructConfig(filename string, name string, dir string) { if outb, err = yaml.Marshal(g); err != nil { log.Fatalf("Cannot Marshal generator '%s', err: %s", g.Name, err) } - err = ioutil.WriteFile(filepath.Join(generatorsDir, g.Name+".yml"), outb, 0644) + err = os.WriteFile(filepath.Join(generatorsDir, g.Name+".yml"), outb, 0644) if err != nil { log.Fatalf("Error writing file %s", filepath.Join(generatorsDir, g.Name+".yml")) } diff --git a/internal/share_test.go b/internal/share_test.go index 1793336..188e974 100644 --- a/internal/share_test.go +++ b/internal/share_test.go @@ -146,3 +146,251 @@ func TestSharePullFile(t *testing.T) { _, err = os.Stat(filepath.Join(os.ExpandEnv("$GOGEN_TMPDIR"), ".configcache_testuser%2Ftestconfig")) assert.NoError(t, err, "Couldn't find cache file") } + +func TestSharePullShortName(t *testing.T) { + // Test Pull with a short name (no "/" in gogen string) + originalGet := Get + defer func() { Get = originalGet }() + + Get = func(q string) (GogenInfo, error) { + return GogenInfo{ + Gogen: "shortname", + Name: "shortname", + Description: "Short name test", + Owner: "testuser", + Version: 1, + Config: "sample: test\nname: shortname", + }, nil + } + + os.Setenv("GOGEN_HOME", "..") + dir := t.TempDir() + + Pull("shortname", dir, false) + _, err := os.Stat(filepath.Join(dir, "shortname.yml")) + assert.NoError(t, err, "Couldn't find file shortname.yml") +} + +func TestSharePullFileCached(t *testing.T) { + // Test PullFile when cache exists and version matches → uses cached content + originalGet := Get + defer func() { Get = originalGet }() + + Get = func(q string) (GogenInfo, error) { + return GogenInfo{ + 
Gogen: "testuser/cached", + Name: "cached", + Owner: "testuser", + Version: 5, + Config: "should not be used", + }, nil + } + + tmpdir := t.TempDir() + os.Setenv("GOGEN_TMPDIR", tmpdir) + defer os.Unsetenv("GOGEN_TMPDIR") + + // Pre-create version cache with matching version + versionCacheFile := filepath.Join(tmpdir, ".versioncache_testuser%2Fcached") + os.WriteFile(versionCacheFile, []byte("5"), 0644) + + // Pre-create config cache with different content + cacheFile := filepath.Join(tmpdir, ".configcache_testuser%2Fcached") + os.WriteFile(cacheFile, []byte("cached config content"), 0644) + + outFile := filepath.Join(tmpdir, "output.yml") + PullFile("testuser/cached", outFile) + + // Should use cached content, not the API response + data, err := os.ReadFile(outFile) + assert.NoError(t, err) + assert.Equal(t, "cached config content", string(data)) +} + +func TestSharePullFileVersionMismatch(t *testing.T) { + // Test PullFile when cache version doesn't match → uses API content and updates cache + originalGet := Get + defer func() { Get = originalGet }() + + Get = func(q string) (GogenInfo, error) { + return GogenInfo{ + Gogen: "testuser/mismatch", + Name: "mismatch", + Owner: "testuser", + Version: 10, + Config: "fresh api content", + }, nil + } + + tmpdir := t.TempDir() + os.Setenv("GOGEN_TMPDIR", tmpdir) + defer os.Unsetenv("GOGEN_TMPDIR") + + // Pre-create version cache with OLD version + versionCacheFile := filepath.Join(tmpdir, ".versioncache_testuser%2Fmismatch") + os.WriteFile(versionCacheFile, []byte("5"), 0644) + + // Pre-create config cache with old content + cacheFile := filepath.Join(tmpdir, ".configcache_testuser%2Fmismatch") + os.WriteFile(cacheFile, []byte("old cached content"), 0644) + + outFile := filepath.Join(tmpdir, "output.yml") + PullFile("testuser/mismatch", outFile) + + // Should use API content since version doesn't match + data, err := os.ReadFile(outFile) + assert.NoError(t, err) + assert.Equal(t, "fresh api content", string(data)) + + // 
Cache files should be updated + versionData, _ := os.ReadFile(versionCacheFile) + assert.Equal(t, "10", string(versionData)) + + cachedData, _ := os.ReadFile(cacheFile) + assert.Equal(t, "fresh api content", string(cachedData)) +} + +func TestSharePullWithDeconstructCSV(t *testing.T) { + // Test deconstructConfig with CSV fieldChoice tokens + originalGet := Get + defer func() { Get = originalGet }() + + configYaml := ` +global: + rotInterval: 1 + output: + outputter: devnull + outputTemplate: raw +samples: + - name: csvtest + description: CSV deconstruct test + interval: 1 + count: 1 + endIntervals: 1 + tokens: + - name: city + format: template + type: fieldChoice + field: _raw + srcField: city + sample: markets.csv + fieldChoice: + - city: NYC + state: NY + - city: LA + state: CA + lines: + - _raw: city=$city$ +` + + Get = func(q string) (GogenInfo, error) { + return GogenInfo{ + Gogen: "testuser/csvtest", + Name: "csvtest", + Owner: "testuser", + Version: 1, + Config: configYaml, + }, nil + } + + os.Setenv("GOGEN_HOME", "..") + dir := t.TempDir() + + Pull("testuser/csvtest", dir, true) + _, err := os.Stat(filepath.Join(dir, "samples", "markets.csv")) + assert.NoError(t, err, "Couldn't find samples/markets.csv") + _, err = os.Stat(filepath.Join(dir, "samples", "csvtest.yml")) + assert.NoError(t, err, "Couldn't find samples/csvtest.yml") +} + +func TestSharePullWithDeconstructGenerator(t *testing.T) { + // Test deconstructConfig with generator that has a fileName + originalGet := Get + defer func() { Get = originalGet }() + + configYaml := ` +global: + rotInterval: 1 + output: + outputter: devnull + outputTemplate: raw +samples: + - name: gentest + description: Generator deconstruct test + generator: mygen + interval: 1 + count: 1 + endIntervals: 1 + lines: + - _raw: test event +generators: + - name: mygen + fileName: /path/to/mygen.lua + script: | + lines = getLines() + return send(lines) +` + + Get = func(q string) (GogenInfo, error) { + return GogenInfo{ + 
Gogen: "testuser/gentest", + Name: "gentest", + Owner: "testuser", + Version: 1, + Config: configYaml, + }, nil + } + + os.Setenv("GOGEN_HOME", "..") + dir := t.TempDir() + + Pull("testuser/gentest", dir, true) + _, err := os.Stat(filepath.Join(dir, "generators", "mygen.lua")) + assert.NoError(t, err, "Couldn't find generators/mygen.lua") + _, err = os.Stat(filepath.Join(dir, "generators", "mygen.yml")) + assert.NoError(t, err, "Couldn't find generators/mygen.yml") +} + +func TestSharePullWithDeconstructTemplates(t *testing.T) { + // Test deconstructConfig with templates + originalGet := Get + defer func() { Get = originalGet }() + + configYaml := ` +global: + rotInterval: 1 + output: + outputter: devnull + outputTemplate: mytemplate +samples: + - name: tmpltest + description: Template deconstruct test + interval: 1 + count: 1 + endIntervals: 1 + lines: + - _raw: test event +templates: + - name: mytemplate + header: "HEADER\n" + row: "{{._raw}}\n" + footer: "FOOTER\n" +` + + Get = func(q string) (GogenInfo, error) { + return GogenInfo{ + Gogen: "testuser/tmpltest", + Name: "tmpltest", + Owner: "testuser", + Version: 1, + Config: configYaml, + }, nil + } + + os.Setenv("GOGEN_HOME", "..") + dir := t.TempDir() + + Pull("testuser/tmpltest", dir, true) + _, err := os.Stat(filepath.Join(dir, "templates", "mytemplate.yml")) + assert.NoError(t, err, "Couldn't find templates/mytemplate.yml") +} diff --git a/logger/logger.go b/logger/logger.go index c713fb8..cd05606 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -1,4 +1,4 @@ -package logging +package logger import ( "os" diff --git a/logger/logger_test.go b/logger/logger_test.go index e2a9dc5..b9e137a 100644 --- a/logger/logger_test.go +++ b/logger/logger_test.go @@ -1,4 +1,4 @@ -package logging +package logger import ( "bytes" diff --git a/main.go b/main.go index 7b9690d..134e6c1 100644 --- a/main.go +++ b/main.go @@ -87,7 +87,8 @@ func Setup(clic *cli.Context) { } if len(clic.String("config")) > 0 { cstr := 
clic.String("config") - if cstr[0:4] == "http" || cstr[len(cstr)-3:] == "yml" || cstr[len(cstr)-4:] == "yaml" || cstr[len(cstr)-4:] == "json" { + ext := filepath.Ext(cstr) + if strings.HasPrefix(cstr, "http") || ext == ".yml" || ext == ".yaml" || ext == ".json" { os.Setenv("GOGEN_FULLCONFIG", cstr) } else { config.PullFile(cstr, filepath.Join(os.ExpandEnv("$GOGEN_TMPDIR"), ".config.yml")) @@ -107,7 +108,7 @@ func Setup(clic *cli.Context) { c.Global.GeneratorWorkers = clic.Int("generators") } if clic.Int("outputters") > 0 { - log.Infof("Setting generators to %d", clic.Int("outputters")) + log.Infof("Setting outputters to %d", clic.Int("outputters")) c.Global.OutputWorkers = clic.Int("outputters") } if clic.Bool("addTime") { @@ -123,6 +124,21 @@ func Setup(clic *cli.Context) { c.Global.CacheIntervals = 2147483647 } + applySampleOutputOverrides(c, clic) + + // Must call from runtime in case we are overriding AddTime or Facility from command line + c.SetupSystemTokens() + + // log.Debugf("Global: %#v", c.Global) + // log.Debugf("Default Tokens: %#v", c.DefaultTokens) + // log.Debugf("Default Sample: %#v", c.DefaultSample) + // log.Debugf("Samples: %#v", c.Samples) + // log.Debugf("Pretty Values %# v\n", pretty.Formatter(c)) + // j, _ := json.MarshalIndent(c, "", " ") + // log.Debugf("JSON Config: %s\n", j) +} + +func applySampleOutputOverrides(c *config.Config, clic *cli.Context) { for i := 0; i < len(c.Samples); i++ { if len(clic.String("outputter")) > 0 { log.Infof("Setting outputter to '%s'", clic.String("outputter")) @@ -161,17 +177,6 @@ func Setup(clic *cli.Context) { c.Samples[i].Output.BufferBytes = clic.Int("bufferBytes") } } - - // Must call from runtime in case we are overriding AddTime or Facility from command line - c.SetupSystemTokens() - - // log.Debugf("Global: %#v", c.Global) - // log.Debugf("Default Tokens: %#v", c.DefaultTokens) - // log.Debugf("Default Sample: %#v", c.DefaultSample) - // log.Debugf("Samples: %#v", c.Samples) - // log.Debugf("Pretty 
Values %# v\n", pretty.Formatter(c)) - // j, _ := json.MarshalIndent(c, "", " ") - // log.Debugf("JSON Config: %s\n", j) } func table(l []config.GogenList) { @@ -365,7 +370,10 @@ func main() { Usage: "List all published Gogens", Action: func(clic *cli.Context) error { fmt.Printf("Showing all Gogens:\n\n") - l := config.List() + l, err := config.List() + if err != nil { + log.Fatalf("Error listing Gogens: %s", err) + } table(l) return nil }, @@ -380,7 +388,10 @@ func main() { } q = strings.TrimRight(q, " ") fmt.Printf("Returning results for search: \"%s\"\n\n", q) - l := config.Search(q) + l, err := config.Search(q) + if err != nil { + log.Fatalf("Error searching Gogens: %s", err) + } if len(l) > 0 { table(l) } else { diff --git a/outputter/devnull.go b/outputter/devnull.go index c359a89..2cff6a6 100644 --- a/outputter/devnull.go +++ b/outputter/devnull.go @@ -2,7 +2,6 @@ package outputter import ( "io" - "io/ioutil" config "github.com/coccyx/gogen/internal" ) @@ -10,7 +9,7 @@ import ( type devnull struct{} func (foo devnull) Send(item *config.OutQueueItem) error { - _, err := io.Copy(ioutil.Discard, item.IO.R) + _, err := io.Copy(io.Discard, item.IO.R) return err } diff --git a/outputter/file.go b/outputter/file.go index 1a19df0..24a6006 100644 --- a/outputter/file.go +++ b/outputter/file.go @@ -18,7 +18,7 @@ type file struct { } func (f *file) Send(item *config.OutQueueItem) error { - if f.initialized == false { + if !f.initialized { info, err := os.Stat(item.S.Output.FileName) // File doesn't exist, so create if os.IsNotExist(err) { diff --git a/outputter/http.go b/outputter/http.go index 7fc5f63..aef929d 100644 --- a/outputter/http.go +++ b/outputter/http.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "fmt" "io" - "io/ioutil" "math/rand" "net/http" @@ -24,7 +23,7 @@ type httpout struct { } func (h *httpout) Send(item *config.OutQueueItem) error { - if h.initialized == false { + if !h.initialized { tr := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: 
true}, DisableKeepAlives: true, MaxIdleConnsPerHost: -1} h.client = &http.Client{Transport: tr, Timeout: item.S.Output.Timeout} h.buf = bytes.NewBuffer([]byte{}) @@ -56,13 +55,13 @@ func (h *httpout) flush() error { if err != nil && h.resp == nil { return fmt.Errorf("Error making request from sample '%s' to endpoint '%s': %s", h.lastSampleName, h.endpoint, err) } - body, err := ioutil.ReadAll(h.resp.Body) + defer h.resp.Body.Close() + body, err := io.ReadAll(h.resp.Body) if err != nil { return fmt.Errorf("Error making request from sample '%s' to endpoint '%s': %s", h.lastSampleName, h.endpoint, err) } else if h.resp.StatusCode < 200 || h.resp.StatusCode > 299 { return fmt.Errorf("Error making request from sample '%s' to endpoint '%s', status '%d': %s", h.lastSampleName, h.endpoint, h.resp.StatusCode, body) } - h.resp.Body.Close() h.buf.Reset() return nil } diff --git a/outputter/kafka.go b/outputter/kafka.go index 7489fbb..2fac8b0 100644 --- a/outputter/kafka.go +++ b/outputter/kafka.go @@ -19,7 +19,7 @@ type kafkaout struct { } func (k *kafkaout) Send(item *config.OutQueueItem) error { - if k.initialized == false { + if !k.initialized { var err error if len(item.S.Output.Endpoints) < 1 { return fmt.Errorf("No configured brokers") diff --git a/outputter/network.go b/outputter/network.go index ebb890d..eff6012 100644 --- a/outputter/network.go +++ b/outputter/network.go @@ -14,7 +14,7 @@ type network struct { } func (n *network) Send(item *config.OutQueueItem) error { - if n.initialized == false { + if !n.initialized { endpoint := item.S.Output.Endpoints[rand.Intn(len(item.S.Output.Endpoints))] conn, err := net.DialTimeout(item.S.Output.Protocol, endpoint, item.S.Output.Timeout) if err != nil { diff --git a/outputter/outputter.go b/outputter/outputter.go index d580660..88de8ea 100644 --- a/outputter/outputter.go +++ b/outputter/outputter.go @@ -20,6 +20,7 @@ var ( Mutex sync.RWMutex lastTS time.Time rotchan chan *config.OutputStats + rotOnce sync.Once rotwg 
sync.WaitGroup gout [config.MaxOutputThreads]config.Outputter lasterr [config.MaxOutputThreads]lastError @@ -40,13 +41,22 @@ func init() { cacheBufs = make(map[string]*bytes.Buffer) } +// InitROT initializes the ROT channel and readStats goroutine. Safe to call +// multiple times; initialization only happens once until ReadFinal resets it. +// Called automatically by ROT, but can be called separately for testing. +func InitROT(c *config.Config) { + rotOnce.Do(func() { + rotInterval = c.Global.ROTInterval + rotchan = make(chan *config.OutputStats) + rotwg.Add(1) + go readStats() + }) +} + // ROT starts the Read Out Thread which will log statistics about what's being output // ROT is intended to be started as a goroutine which will log output every c. func ROT(c *config.Config) { - rotInterval = c.Global.ROTInterval - rotchan = make(chan *config.OutputStats) - rotwg.Add(1) - go readStats() + InitROT(c) lastEventsWritten := make(map[string]int64) lastBytesWritten := make(map[string]int64) @@ -83,6 +93,8 @@ func ROT(c *config.Config) { func ReadFinal() { close(rotchan) rotwg.Wait() + // Reset so ROT can be re-initialized (needed for tests) + rotOnce = sync.Once{} totalEvents := int64(0) totalBytes := int64(0) @@ -158,15 +170,7 @@ func write(item *config.OutQueueItem) { } tempbytes, err = w.Write(jb) case "splunkhec": - if _, ok := line["_raw"]; ok { - line["event"] = line["_raw"] - delete(line, "_raw") - } - if _, ok := line["_time"]; ok { - line["time"] = line["_time"] - delete(line, "_time") - } - // TODO Refactor to avoid copy pasta, being lazy for now + template.TransformHECFields(line) jb, err := json.Marshal(line) if err != nil { log.Errorf("Error marshaling json: %s", err) diff --git a/outputter/send_test.go b/outputter/send_test.go new file mode 100644 index 0000000..1f52406 --- /dev/null +++ b/outputter/send_test.go @@ -0,0 +1,682 @@ +package outputter + +import ( + "bytes" + "io" + "math/rand" + "net" + "net/http" + "net/http/httptest" + "os" + 
"path/filepath" + "strings" + "sync" + "testing" + "time" + + config "github.com/coccyx/gogen/internal" + "github.com/stretchr/testify/assert" +) + +func TestSetup(t *testing.T) { + tests := []struct { + name string + outputter string + expectedType interface{} + }{ + {"stdout", "stdout", &stdout{}}, + {"devnull", "devnull", &devnull{}}, + {"file", "file", &file{}}, + {"http", "http", &httpout{}}, + {"buf", "buf", &buf{}}, + {"network", "network", &network{}}, + {"kafka", "kafka", &kafkaout{}}, + {"unknown defaults to stdout", "unknowntype", &stdout{}}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Use a unique gout slot for each test + num := 99 // use last slot to avoid conflicts + gout[num] = nil + + s := &config.Sample{ + Name: "test", + Output: &config.Output{ + Outputter: tc.outputter, + }, + } + item := &config.OutQueueItem{S: s} + source := rand.NewSource(0) + gen := rand.New(source) + + result := setup(gen, item, num) + assert.IsType(t, tc.expectedType, result) + + // Clean up + gout[num] = nil + }) + } +} + +func TestDevnullSend(t *testing.T) { + d := &devnull{} + oio := config.NewOutputIO() + item := &config.OutQueueItem{ + S: &config.Sample{Name: "test"}, + IO: oio, + } + + go func() { + io.WriteString(oio.W, "test data") + oio.W.Close() + }() + + err := d.Send(item) + assert.NoError(t, err) +} + +func TestDevnullClose(t *testing.T) { + d := &devnull{} + err := d.Close() + assert.NoError(t, err) +} + +func TestBufSend(t *testing.T) { + var b bytes.Buffer + s := &config.Sample{ + Name: "test", + Buf: &b, + } + oio := config.NewOutputIO() + item := &config.OutQueueItem{ + S: s, + IO: oio, + } + + go func() { + io.WriteString(oio.W, "buffered data\n") + oio.W.Close() + }() + + bu := &buf{} + err := bu.Send(item) + assert.NoError(t, err) + assert.Equal(t, "buffered data\n", b.String()) +} + +func TestStdoutSend(t *testing.T) { + // Redirect stdout to a pipe + origStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + 
oio := config.NewOutputIO() + item := &config.OutQueueItem{ + S: &config.Sample{Name: "test"}, + IO: oio, + } + + go func() { + io.WriteString(oio.W, "stdout data\n") + oio.W.Close() + }() + + so := &stdout{} + err := so.Send(item) + assert.NoError(t, err) + + w.Close() + var buf bytes.Buffer + io.Copy(&buf, r) + os.Stdout = origStdout + + assert.Equal(t, "stdout data\n", buf.String()) +} + +func TestStdoutClose(t *testing.T) { + so := &stdout{} + err := so.Close() + assert.NoError(t, err) +} + +func TestFileSendAndRotation(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "testfile.log") + + s := &config.Sample{ + Name: "filesample", + Output: &config.Output{ + FileName: filename, + MaxBytes: 50, // Very small to trigger rotation + BackupFiles: 2, + }, + } + + f := &file{} + + // Write enough data to trigger rotation + for i := 0; i < 5; i++ { + oio := config.NewOutputIO() + item := &config.OutQueueItem{S: s, IO: oio} + + go func() { + io.WriteString(oio.W, strings.Repeat("X", 30)+"\n") + oio.W.Close() + }() + + err := f.Send(item) + assert.NoError(t, err) + } + + // Check that backup files were created + _, err := os.Stat(filename) + assert.NoError(t, err, "main file should exist") + + _, err = os.Stat(filename + ".1") + assert.NoError(t, err, "backup .1 should exist") + + f.Close() +} + +func TestFileSendExistingFile(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "existing.log") + + // Pre-create the file with some content + os.WriteFile(filename, []byte("pre-existing data\n"), 0644) + + s := &config.Sample{ + Name: "fileexisting", + Output: &config.Output{ + FileName: filename, + MaxBytes: 10000000, + BackupFiles: 2, + }, + } + + f := &file{} + + oio := config.NewOutputIO() + item := &config.OutQueueItem{S: s, IO: oio} + + go func() { + io.WriteString(oio.W, "appended data\n") + oio.W.Close() + }() + + err := f.Send(item) + assert.NoError(t, err) + + // Verify the file has both old and new data + data, _ := 
os.ReadFile(filename) + assert.Contains(t, string(data), "pre-existing data") + assert.Contains(t, string(data), "appended data") + + f.Close() +} + +func TestFileClose(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "closefile.log") + + s := &config.Sample{ + Name: "fileclose", + Output: &config.Output{ + FileName: filename, + MaxBytes: 1000000, + BackupFiles: 2, + }, + } + + f := &file{} + + oio := config.NewOutputIO() + item := &config.OutQueueItem{S: s, IO: oio} + go func() { + io.WriteString(oio.W, "data\n") + oio.W.Close() + }() + f.Send(item) + + // Close should work + err := f.Close() + assert.NoError(t, err) + + // Close again should be idempotent + err = f.Close() + assert.NoError(t, err) +} + +func TestHTTPSendAndFlush(t *testing.T) { + var received bytes.Buffer + var mu sync.Mutex + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + io.Copy(&received, r.Body) + mu.Unlock() + w.WriteHeader(200) + })) + defer ts.Close() + + s := &config.Sample{ + Name: "httpsample", + Output: &config.Output{ + Endpoints: []string{ts.URL}, + BufferBytes: 10, // Small buffer to trigger flush + Headers: map[string]string{"Content-Type": "application/json"}, + Timeout: 5 * time.Second, + }, + } + + h := &httpout{} + + // Send enough data to exceed buffer and trigger flush + oio := config.NewOutputIO() + item := &config.OutQueueItem{S: s, IO: oio} + go func() { + io.WriteString(oio.W, strings.Repeat("D", 50)+"\n") + oio.W.Close() + }() + + err := h.Send(item) + assert.NoError(t, err) + + // Verify server received data + mu.Lock() + data := received.String() + mu.Unlock() + assert.NotEmpty(t, data, "HTTP server should have received data") + + // Close flushes remaining data + err = h.Close() + assert.NoError(t, err) +} + +func TestNetworkSend(t *testing.T) { + // Start a TCP listener on a random port + ln, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + defer ln.Close() + + var 
received bytes.Buffer + done := make(chan struct{}) + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + io.Copy(&received, conn) + conn.Close() + close(done) + }() + + s := &config.Sample{ + Name: "netsample", + Output: &config.Output{ + Endpoints: []string{ln.Addr().String()}, + Protocol: "tcp", + Timeout: 5 * time.Second, + }, + } + + n := &network{} + + oio := config.NewOutputIO() + item := &config.OutQueueItem{S: s, IO: oio} + go func() { + io.WriteString(oio.W, "network data\n") + oio.W.Close() + }() + + err = n.Send(item) + assert.NoError(t, err) + + // Close the connection so the listener goroutine can finish + n.Close() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for network data") + } + + assert.Equal(t, "network data\n", received.String()) +} + +func TestNetworkClose(t *testing.T) { + n := &network{} + // Close with no connection should not error + err := n.Close() + assert.NoError(t, err) + assert.False(t, n.initialized) +} + +func TestBufClose(t *testing.T) { + b := &buf{} + err := b.Close() + assert.NoError(t, err) +} + +func TestStartDevnullWorker(t *testing.T) { + cleanup := initROT() + defer cleanup() + + // Reset gout slot + gout[0] = nil + + oq := make(chan *config.OutQueueItem) + oqs := make(chan int) + + go Start(oq, oqs, 0) + + // Send an item through the pipeline + s := &config.Sample{ + Name: "starttest", + Output: &config.Output{ + Outputter: "devnull", + OutputTemplate: "raw", + }, + } + events := []map[string]string{ + {"_raw": "test event for start"}, + } + item := &config.OutQueueItem{ + S: s, + Events: events, + Cache: &config.CacheItem{}, + } + oq <- item + + // Close the queue and wait for worker to finish + close(oq) + select { + case <-oqs: + // Worker finished + case <-time.After(5 * time.Second): + t.Fatal("Start worker did not finish in time") + } + + // Verify gout slot was cleared + assert.Nil(t, gout[0]) +} + +func TestStartMultipleItems(t *testing.T) { + 
cleanup := initROT() + defer cleanup() + + gout[0] = nil + + oq := make(chan *config.OutQueueItem) + oqs := make(chan int) + + go Start(oq, oqs, 0) + + s := &config.Sample{ + Name: "multistart", + Output: &config.Output{ + Outputter: "devnull", + OutputTemplate: "raw", + }, + } + + for i := 0; i < 5; i++ { + events := []map[string]string{ + {"_raw": "event number"}, + } + item := &config.OutQueueItem{ + S: s, + Events: events, + Cache: &config.CacheItem{}, + } + oq <- item + } + + close(oq) + select { + case <-oqs: + case <-time.After(5 * time.Second): + t.Fatal("Start worker did not finish in time") + } + + // Check that events were accounted for + time.Sleep(50 * time.Millisecond) + Mutex.RLock() + ew := EventsWritten["multistart"] + Mutex.RUnlock() + assert.Equal(t, int64(5), ew) +} + +func TestStartEmptyEvents(t *testing.T) { + cleanup := initROT() + defer cleanup() + + gout[0] = nil + + oq := make(chan *config.OutQueueItem) + oqs := make(chan int) + + go Start(oq, oqs, 0) + + s := &config.Sample{ + Name: "emptyevents", + Output: &config.Output{ + Outputter: "devnull", + OutputTemplate: "raw", + }, + } + // Send item with no events - should skip the write/send + item := &config.OutQueueItem{ + S: s, + Events: []map[string]string{}, + Cache: &config.CacheItem{}, + } + oq <- item + + close(oq) + select { + case <-oqs: + case <-time.After(5 * time.Second): + t.Fatal("Start worker did not finish in time") + } +} + +func TestStartCloseOnChannelClose(t *testing.T) { + cleanup := initROT() + defer cleanup() + + gout[0] = nil + + oq := make(chan *config.OutQueueItem) + oqs := make(chan int) + + go Start(oq, oqs, 0) + + s := &config.Sample{ + Name: "closetest", + Output: &config.Output{ + Outputter: "devnull", + OutputTemplate: "raw", + }, + } + // Send one real item so lastS is set, then close + events := []map[string]string{ + {"_raw": "test event"}, + } + item := &config.OutQueueItem{ + S: s, + Events: events, + Cache: &config.CacheItem{}, + } + oq <- item + + // 
Close the channel - should trigger the Close() path on the outputter + close(oq) + select { + case <-oqs: + case <-time.After(5 * time.Second): + t.Fatal("Start worker did not finish in time") + } + + // gout should be cleared + assert.Nil(t, gout[0]) +} + +func TestStartSendError(t *testing.T) { + cleanup := initROT() + defer cleanup() + + // Use a network outputter pointed at a bad address to trigger Send error + gout[0] = nil + + oq := make(chan *config.OutQueueItem) + oqs := make(chan int) + + go Start(oq, oqs, 0) + + s := &config.Sample{ + Name: "senderror", + Output: &config.Output{ + Outputter: "network", + OutputTemplate: "raw", + Endpoints: []string{"127.0.0.1:1"}, // Should fail to connect + Protocol: "tcp", + }, + } + events := []map[string]string{ + {"_raw": "error event"}, + } + item := &config.OutQueueItem{ + S: s, + Events: events, + Cache: &config.CacheItem{}, + } + oq <- item + + close(oq) + select { + case <-oqs: + case <-time.After(10 * time.Second): + t.Fatal("Start worker did not finish in time") + } +} + +func TestHTTPFlushNon200(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + w.Write([]byte("internal server error")) + })) + defer ts.Close() + + s := &config.Sample{ + Name: "httpfail", + Output: &config.Output{ + Endpoints: []string{ts.URL}, + BufferBytes: 10, + Headers: map[string]string{"Content-Type": "text/plain"}, + Timeout: 5 * time.Second, + }, + } + + h := &httpout{} + + oio := config.NewOutputIO() + item := &config.OutQueueItem{S: s, IO: oio} + go func() { + io.WriteString(oio.W, strings.Repeat("X", 50)+"\n") + oio.W.Close() + }() + + err := h.Send(item) + // flush should return an error due to non-200 status + assert.Error(t, err) + assert.Contains(t, err.Error(), "500") +} + +func TestHTTPCloseFlushError(t *testing.T) { + // Use a server that accepts the first request (Send flush) but returns error on the second (Close flush) + calls := 0 + ts := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls++ + if calls > 1 { + w.WriteHeader(500) + w.Write([]byte("flush error")) + return + } + w.WriteHeader(200) + })) + defer ts.Close() + + s := &config.Sample{ + Name: "httpclose", + Output: &config.Output{ + Endpoints: []string{ts.URL}, + BufferBytes: 10, // Small buffer to trigger flush on first Send + Headers: map[string]string{}, + Timeout: 5 * time.Second, + }, + } + + h := &httpout{} + + // First Send: exceeds buffer, triggers flush (call #1 → 200 OK) + oio := config.NewOutputIO() + item := &config.OutQueueItem{S: s, IO: oio} + go func() { + io.WriteString(oio.W, strings.Repeat("X", 50)) + oio.W.Close() + }() + + err := h.Send(item) + assert.NoError(t, err) + + // Add more data to buffer for Close to flush + h.buf.WriteString("leftover data") + + // Close should flush remaining data and get 500 error (call #2) + err = h.Close() + assert.Error(t, err) + assert.Contains(t, err.Error(), "500") +} + +func TestStartSendErrorRepeat(t *testing.T) { + cleanup := initROT() + defer cleanup() + + gout[0] = nil + + oq := make(chan *config.OutQueueItem) + oqs := make(chan int) + + go Start(oq, oqs, 0) + + s := &config.Sample{ + Name: "senderrorrepeat", + Output: &config.Output{ + Outputter: "network", + OutputTemplate: "raw", + Endpoints: []string{"127.0.0.1:1"}, + Protocol: "tcp", + }, + } + // Send multiple items to trigger repeat error path (lasterr[num].count++) + for i := 0; i < 3; i++ { + events := []map[string]string{ + {"_raw": "error event repeat"}, + } + item := &config.OutQueueItem{ + S: s, + Events: events, + Cache: &config.CacheItem{}, + } + oq <- item + } + + close(oq) + select { + case <-oqs: + case <-time.After(15 * time.Second): + t.Fatal("Start worker did not finish in time") + } +} diff --git a/outputter/write_test.go b/outputter/write_test.go new file mode 100644 index 0000000..d321cbe --- /dev/null +++ b/outputter/write_test.go @@ -0,0 +1,451 @@ +package outputter 
+ +import ( + "bytes" + "encoding/json" + "io" + "strings" + "sync" + "testing" + "time" + + config "github.com/coccyx/gogen/internal" + "github.com/coccyx/gogen/template" + "github.com/stretchr/testify/assert" +) + +// initROT initializes the rotchan and readStats goroutine needed by write(). +// Returns a cleanup function to call via defer. +func initROT() func() { + Mutex.Lock() + BytesWritten = make(map[string]int64) + EventsWritten = make(map[string]int64) + rotwg = sync.WaitGroup{} + rotchan = make(chan *config.OutputStats) + Mutex.Unlock() + rotwg.Add(1) + go readStats() + return func() { + close(rotchan) + rotwg.Wait() + } +} + +func makeOutQueueItem(sampleName, outputTemplate, outputter string, events []map[string]string) *config.OutQueueItem { + s := &config.Sample{ + Name: sampleName, + Output: &config.Output{ + Outputter: outputter, + OutputTemplate: outputTemplate, + }, + } + oio := config.NewOutputIO() + return &config.OutQueueItem{ + S: s, + Events: events, + IO: oio, + Cache: &config.CacheItem{}, + } +} + +func readFromPipe(item *config.OutQueueItem) string { + var buf bytes.Buffer + io.Copy(&buf, item.IO.R) + return buf.String() +} + +func TestWriteRaw(t *testing.T) { + cleanup := initROT() + defer cleanup() + + events := []map[string]string{ + {"_raw": "hello world"}, + } + item := makeOutQueueItem("rawsample", "raw", "stdout", events) + + var result string + done := make(chan struct{}) + go func() { + result = readFromPipe(item) + close(done) + }() + + write(item) + <-done + + assert.Contains(t, result, "hello world") +} + +func TestWriteJSON(t *testing.T) { + cleanup := initROT() + defer cleanup() + + events := []map[string]string{ + {"_raw": "test event", "host": "myhost"}, + } + item := makeOutQueueItem("jsonsample", "json", "stdout", events) + + var result string + done := make(chan struct{}) + go func() { + result = readFromPipe(item) + close(done) + }() + + write(item) + <-done + + var parsed map[string]string + lines := 
strings.TrimSpace(result) + err := json.Unmarshal([]byte(strings.Split(lines, "\n")[0]), &parsed) + assert.NoError(t, err) + assert.Equal(t, "test event", parsed["_raw"]) + assert.Equal(t, "myhost", parsed["host"]) +} + +func TestWriteSplunkHEC(t *testing.T) { + cleanup := initROT() + defer cleanup() + + events := []map[string]string{ + {"_raw": "splunk event", "_time": "1234567890"}, + } + item := makeOutQueueItem("hecsample", "splunkhec", "stdout", events) + + var result string + done := make(chan struct{}) + go func() { + result = readFromPipe(item) + close(done) + }() + + write(item) + <-done + + var parsed map[string]string + lines := strings.TrimSpace(result) + err := json.Unmarshal([]byte(strings.Split(lines, "\n")[0]), &parsed) + assert.NoError(t, err) + assert.Equal(t, "splunk event", parsed["event"], "splunkhec should remap _raw to event") + assert.Equal(t, "1234567890", parsed["time"], "splunkhec should remap _time to time") + assert.Empty(t, parsed["_raw"], "_raw should be deleted") + assert.Empty(t, parsed["_time"], "_time should be deleted") +} + +func TestWriteRFC3164(t *testing.T) { + cleanup := initROT() + defer cleanup() + + events := []map[string]string{ + {"_raw": "syslog msg", "_time": "Oct 20 12:00:00", "priority": "13", "host": "myhost", "tag": "gogen", "pid": "1234"}, + } + item := makeOutQueueItem("rfc3164sample", "rfc3164", "stdout", events) + + var result string + done := make(chan struct{}) + go func() { + result = readFromPipe(item) + close(done) + }() + + write(item) + <-done + + assert.Contains(t, result, "<13>") + assert.Contains(t, result, "Oct 20 12:00:00") + assert.Contains(t, result, "myhost") + assert.Contains(t, result, "gogen[1234]") + assert.Contains(t, result, "syslog msg") +} + +func TestWriteRFC5424(t *testing.T) { + cleanup := initROT() + defer cleanup() + + events := []map[string]string{ + {"_raw": "syslog5424 msg", "_time": "2001-10-20T12:00:00Z", "priority": "13", "host": "myhost", "appName": "gogen", "pid": "1234", 
"extra": "val"}, + } + item := makeOutQueueItem("rfc5424sample", "rfc5424", "stdout", events) + + var result string + done := make(chan struct{}) + go func() { + result = readFromPipe(item) + close(done) + }() + + write(item) + <-done + + assert.Contains(t, result, "<13>1") + assert.Contains(t, result, "myhost") + assert.Contains(t, result, "gogen") + assert.Contains(t, result, "syslog5424 msg") + assert.Contains(t, result, "[meta") + assert.Contains(t, result, `extra="val"`) +} + +func TestWriteElasticsearch(t *testing.T) { + cleanup := initROT() + defer cleanup() + + events := []map[string]string{ + {"_raw": "es event", "index": "testindex"}, + } + item := makeOutQueueItem("essample", "elasticsearch", "stdout", events) + + var result string + done := make(chan struct{}) + go func() { + result = readFromPipe(item) + close(done) + }() + + write(item) + <-done + + assert.Contains(t, result, `"_index": "testindex"`) + assert.Contains(t, result, `"_type": "doc"`) + // _raw should be remapped to message + var parsed map[string]interface{} + lines := strings.Split(strings.TrimSpace(result), "\n") + assert.GreaterOrEqual(t, len(lines), 2, "elasticsearch should produce index header + body") + err := json.Unmarshal([]byte(lines[1]), &parsed) + assert.NoError(t, err) + assert.Equal(t, "es event", parsed["message"]) +} + +func TestWriteDevnull(t *testing.T) { + cleanup := initROT() + defer cleanup() + + events := []map[string]string{ + {"_raw": "devnull event data"}, + } + item := makeOutQueueItem("devnullsample", "raw", "devnull", events) + + var result string + done := make(chan struct{}) + go func() { + result = readFromPipe(item) + close(done) + }() + + write(item) + <-done + + // devnull should not write any content through the pipe + assert.Empty(t, result) + + // But bytes should still be accounted for + time.Sleep(50 * time.Millisecond) + Mutex.RLock() + bw := BytesWritten["devnullsample"] + ew := EventsWritten["devnullsample"] + Mutex.RUnlock() + assert.Greater(t, 
bw, int64(0), "bytes should be counted even for devnull") + assert.Equal(t, int64(1), ew) +} + +func TestWriteCacheMiss(t *testing.T) { + cleanup := initROT() + defer cleanup() + + events := []map[string]string{ + {"_raw": "cache miss event"}, + } + item := makeOutQueueItem("cachemiss", "raw", "stdout", events) + item.Cache.UseCache = true // UseCache=true but no cacheBuf exists => cache miss + + var result string + done := make(chan struct{}) + go func() { + result = readFromPipe(item) + close(done) + }() + + write(item) + <-done + + // Cache miss should still write through the normal pipe + assert.Contains(t, result, "cache miss event") +} + +func TestWriteSetCache(t *testing.T) { + cleanup := initROT() + defer cleanup() + + // Clean the cache bufs + cacheMutex.Lock() + delete(cacheBufs, "setcache") + cacheMutex.Unlock() + + events := []map[string]string{ + {"_raw": "cached event"}, + } + item := makeOutQueueItem("setcache", "raw", "stdout", events) + item.Cache.SetCache = true + + var result string + done := make(chan struct{}) + go func() { + result = readFromPipe(item) + close(done) + }() + + write(item) + <-done + + // SetCache should write to both cache buffer and the pipe + assert.Contains(t, result, "cached event") + + // Verify cache buffer was populated + cacheMutex.RLock() + cb, ok := cacheBufs["setcache"] + cacheMutex.RUnlock() + assert.True(t, ok, "cache buffer should be created") + assert.Contains(t, cb.String(), "cached event") +} + +func TestWriteUseCache(t *testing.T) { + cleanup := initROT() + defer cleanup() + + // Pre-populate the cache + cacheMutex.Lock() + cacheBufs["usecache"] = &bytes.Buffer{} + cacheBufs["usecache"].WriteString("previously cached data\n") + cacheMutex.Unlock() + + events := []map[string]string{ + {"_raw": "new event"}, + } + item := makeOutQueueItem("usecache", "raw", "stdout", events) + item.Cache.UseCache = true + + var result string + done := make(chan struct{}) + go func() { + result = readFromPipe(item) + close(done) 
+ }() + + write(item) + <-done + + // Should use the cached data, not the new events + assert.Contains(t, result, "previously cached data") + + // Clean up + cacheMutex.Lock() + delete(cacheBufs, "usecache") + cacheMutex.Unlock() +} + +func TestWriteNonExistentTemplate(t *testing.T) { + cleanup := initROT() + defer cleanup() + + events := []map[string]string{ + {"_raw": "should not appear"}, + } + item := makeOutQueueItem("badtemplate", "nonexistent_template_xyz", "stdout", events) + + var result string + done := make(chan struct{}) + go func() { + result = readFromPipe(item) + close(done) + }() + + write(item) + <-done + + // Non-existent template should produce no output + assert.Empty(t, result) +} + +func TestWriteMultipleEvents(t *testing.T) { + cleanup := initROT() + defer cleanup() + + events := []map[string]string{ + {"_raw": "event1"}, + {"_raw": "event2"}, + {"_raw": "event3"}, + } + item := makeOutQueueItem("multisample", "raw", "stdout", events) + + var result string + done := make(chan struct{}) + go func() { + result = readFromPipe(item) + close(done) + }() + + write(item) + <-done + + assert.Contains(t, result, "event1") + assert.Contains(t, result, "event2") + assert.Contains(t, result, "event3") + + // Verify accounting + time.Sleep(50 * time.Millisecond) + Mutex.RLock() + ew := EventsWritten["multisample"] + Mutex.RUnlock() + assert.Equal(t, int64(3), ew) +} + +func TestWriteKafkaNoNewlines(t *testing.T) { + cleanup := initROT() + defer cleanup() + + events := []map[string]string{ + {"_raw": "kafka event"}, + } + item := makeOutQueueItem("kafkasample", "raw", "kafka", events) + + var result string + done := make(chan struct{}) + go func() { + result = readFromPipe(item) + close(done) + }() + + write(item) + <-done + + // Kafka should not append newlines + assert.Equal(t, "kafka event", result) +} + +func TestWriteCustomTemplate(t *testing.T) { + cleanup := initROT() + defer cleanup() + + // Register a custom template + _ = 
template.New("customtest_header", "HEADER\n") + _ = template.New("customtest_row", "ROW:{{._raw}}\n") + _ = template.New("customtest_footer", "FOOTER\n") + + events := []map[string]string{ + {"_raw": "custom line"}, + } + item := makeOutQueueItem("customsample", "customtest", "stdout", events) + + var result string + done := make(chan struct{}) + go func() { + result = readFromPipe(item) + close(done) + }() + + write(item) + <-done + + assert.Contains(t, result, "HEADER") + assert.Contains(t, result, "ROW:custom line") + assert.Contains(t, result, "FOOTER") +} diff --git a/rater/rater_test.go b/rater/rater_test.go index 470d3db..047aef5 100644 --- a/rater/rater_test.go +++ b/rater/rater_test.go @@ -1,10 +1,13 @@ package rater import ( + "os" + "path/filepath" "testing" "time" config "github.com/coccyx/gogen/internal" + "github.com/coccyx/gogen/outputter" "github.com/stretchr/testify/assert" ) @@ -14,3 +17,144 @@ func TestRandomizeCount(t *testing.T) { count := EventRate(s, time.Now(), 10) assert.Equal(t, 11, count) } + +func TestTokenRateDefault(t *testing.T) { + dr := &DefaultRater{} + token := config.Token{Name: "test"} + rate := dr.TokenRate(token, time.Now()) + assert.Equal(t, 1.0, rate) +} + +func TestTokenRateConfig(t *testing.T) { + os.Setenv("GOGEN_HOME", "..") + os.Setenv("GOGEN_ALWAYS_REFRESH", "1") + os.Setenv("GOGEN_FULLCONFIG", filepath.Join("..", "tests", "rater", "configrater.yml")) + + c := config.NewConfig() + r := c.FindRater("testconfigrater") + + cr := &ConfigRater{c: r} + token := config.Token{Name: "test"} + + loc, _ := time.LoadLocation("Local") + n := time.Date(2001, 10, 20, 0, 0, 0, 100000, loc) + rate := cr.TokenRate(token, n) + assert.Equal(t, 2.0, rate) +} + +func TestTokenRateKBps(t *testing.T) { + kr := &KBpsRater{c: &config.RaterConfig{Name: "kbps"}} + token := config.Token{Name: "test"} + rate := kr.TokenRate(token, time.Now()) + assert.Equal(t, 1.0, rate) +} + +func TestTokenRateScript(t *testing.T) { + os.Setenv("GOGEN_HOME", "..") 
+ os.Setenv("GOGEN_ALWAYS_REFRESH", "1") + os.Setenv("GOGEN_FULLCONFIG", filepath.Join("..", "tests", "rater", "luarater.yml")) + + c := config.NewConfig() + r := c.FindRater("multiply") + + sr := &ScriptRater{c: r} + token := config.Token{Name: "test"} + rate := sr.TokenRate(token, time.Now()) + assert.Equal(t, 2.0, rate) +} + +func TestGetRaterFallback(t *testing.T) { + os.Setenv("GOGEN_HOME", "..") + os.Setenv("GOGEN_ALWAYS_REFRESH", "1") + os.Setenv("GOGEN_FULLCONFIG", filepath.Join("..", "tests", "rater", "defaultrater.yml")) + + r := GetRater("nonexistentrater") + assert.IsType(t, &DefaultRater{}, r, "unknown rater name should fall back to DefaultRater") +} + +func TestEventRateNegativeResult(t *testing.T) { + // A rater returning a negative rate should produce a negative or zero count + s := &config.Sample{RandomizeCount: 0} + // Use a mock rater by pre-setting it + s.Rater = &negativeRater{} + randSource = 2 + count := EventRate(s, time.Now(), 10) + assert.True(t, count <= 0, "negative rate should produce non-positive count, got %d", count) +} + +func TestKBpsEventRateMissingOption(t *testing.T) { + kr := &KBpsRater{ + c: &config.RaterConfig{ + Name: "kbps", + Options: map[string]interface{}{}, + }, + } + s := &config.Sample{Name: "test"} + rate := kr.EventRate(s, time.Now(), 10) + assert.Equal(t, 1.0, rate) +} + +func TestKBpsEventRateWrongType(t *testing.T) { + kr := &KBpsRater{ + c: &config.RaterConfig{ + Name: "kbps", + Options: map[string]interface{}{ + "KBps": "not_a_float", + }, + }, + } + s := &config.Sample{Name: "test"} + rate := kr.EventRate(s, time.Now(), 10) + assert.Equal(t, 1.0, rate) +} + +func TestKBpsEventRateMissingSample(t *testing.T) { + kr := &KBpsRater{ + c: &config.RaterConfig{ + Name: "kbps", + Options: map[string]interface{}{ + "KBps": 100.0, + }, + }, + } + s := &config.Sample{Name: "nonexistent_sample_kbps"} + rate := kr.EventRate(s, time.Now(), 10) + assert.Equal(t, 1.0, rate) +} + +func TestKBpsEventRateWithData(t *testing.T) { 
+ // Pre-populate outputter stats + outputter.Mutex.Lock() + outputter.BytesWritten["kbps_test_sample"] = 10000 + outputter.EventsWritten["kbps_test_sample"] = 100 + outputter.Mutex.Unlock() + defer func() { + outputter.Mutex.Lock() + delete(outputter.BytesWritten, "kbps_test_sample") + delete(outputter.EventsWritten, "kbps_test_sample") + outputter.Mutex.Unlock() + }() + + kr := &KBpsRater{ + c: &config.RaterConfig{ + Name: "kbps", + Options: map[string]interface{}{ + "KBps": 100.0, + }, + }, + t: time.Now().Add(-1 * time.Second), // pretend we started 1s ago + } + s := &config.Sample{Name: "kbps_test_sample"} + rate := kr.EventRate(s, time.Now(), 10) + assert.Equal(t, 1.0, rate) // always returns 1.0 regardless +} + +// negativeRater always returns a negative rate for testing +type negativeRater struct{} + +func (nr *negativeRater) EventRate(s *config.Sample, now time.Time, count int) float64 { + return -1.0 +} +func (nr *negativeRater) TokenRate(t config.Token, now time.Time) float64 { + return -1.0 +} diff --git a/run/run_test.go b/run/run_test.go new file mode 100644 index 0000000..17b8bc1 --- /dev/null +++ b/run/run_test.go @@ -0,0 +1,200 @@ +package run + +import ( + "bytes" + "testing" + "time" + + config "github.com/coccyx/gogen/internal" + "github.com/coccyx/gogen/outputter" + "github.com/stretchr/testify/assert" +) + +// resetRunState resets config and outputter stats for a clean test. 
+func resetRunState() { + config.ResetConfig() + outputter.Mutex.Lock() + outputter.BytesWritten = make(map[string]int64) + outputter.EventsWritten = make(map[string]int64) + outputter.Mutex.Unlock() +} + +func TestRunCompletesWithEndIntervals(t *testing.T) { + resetRunState() + + configStr := ` +global: + utc: true + output: + outputter: devnull + outputTemplate: raw + rotInterval: 1 +samples: + - name: runtest + description: "Run completion test" + interval: 1 + count: 1 + endIntervals: 1 + lines: + - _raw: run test event +` + config.SetupFromString(configStr) + defer config.CleanupConfigAndEnvironment() + + c := config.NewConfig() + assert.NotEmpty(t, c.Samples) + + done := make(chan struct{}) + go func() { + Run(c) + close(done) + }() + + select { + case <-done: + // Run completed successfully + case <-time.After(10 * time.Second): + t.Fatal("Run() did not complete within timeout") + } +} + +func TestRunMultipleSamples(t *testing.T) { + resetRunState() + + configStr := ` +global: + utc: true + output: + outputter: devnull + outputTemplate: raw + rotInterval: 1 +samples: + - name: multi1 + description: "Multi sample 1" + interval: 1 + count: 1 + endIntervals: 1 + lines: + - _raw: event from multi1 + - name: multi2 + description: "Multi sample 2" + interval: 1 + count: 1 + endIntervals: 1 + lines: + - _raw: event from multi2 +` + config.SetupFromString(configStr) + defer config.CleanupConfigAndEnvironment() + + c := config.NewConfig() + assert.Len(t, c.Samples, 2) + + done := make(chan struct{}) + go func() { + Run(c) + close(done) + }() + + select { + case <-done: + outputter.Mutex.RLock() + totalEvents := int64(0) + for _, v := range outputter.EventsWritten { + totalEvents += v + } + outputter.Mutex.RUnlock() + assert.Greater(t, totalEvents, int64(0), "should have generated events") + case <-time.After(10 * time.Second): + t.Fatal("Run() did not complete within timeout") + } +} + +func TestOnceMethod(t *testing.T) { + resetRunState() + + configStr := ` +global: 
+ utc: true + output: + outputter: buf + outputTemplate: json + rotInterval: 1 +samples: + - name: oncemethodtest + description: "Once method test" + interval: 1 + count: 1 + endIntervals: 1 + lines: + - _raw: once method event +` + config.SetupFromString(configStr) + defer config.CleanupConfigAndEnvironment() + + r := Runner{} + + done := make(chan struct{}) + go func() { + r.Once("oncemethodtest") + close(done) + }() + + select { + case <-done: + // Once completed without error + case <-time.After(10 * time.Second): + t.Fatal("Once() did not complete within timeout") + } +} + +func TestOncePublic(t *testing.T) { + resetRunState() + + configStr := ` +global: + utc: true + output: + outputter: buf + outputTemplate: json + rotInterval: 1 +samples: + - name: oncetest + description: "Once test sample" + interval: 1 + count: 1 + endIntervals: 1 + lines: + - _raw: once event data +` + config.SetupFromString(configStr) + defer config.CleanupConfigAndEnvironment() + + c := config.NewConfig() + assert.NotEmpty(t, c.Samples) + + // Set up a buffer for the sample + var buf bytes.Buffer + s := c.FindSampleByName("oncetest") + s.Buf = &buf + + r := Runner{} + + // Start ROT before onceWithConfig (Once() normally does this) + go outputter.ROT(c) + // Give ROT goroutine time to create the new rotchan + time.Sleep(50 * time.Millisecond) + + done := make(chan struct{}) + go func() { + r.onceWithConfig("oncetest", c) + close(done) + }() + + select { + case <-done: + assert.Contains(t, buf.String(), "once event data") + case <-time.After(10 * time.Second): + t.Fatal("Once() did not complete within timeout") + } +} diff --git a/run/runonce_test.go b/run/runonce_test.go index e40ec2d..2b83294 100644 --- a/run/runonce_test.go +++ b/run/runonce_test.go @@ -7,12 +7,12 @@ import ( "time" config "github.com/coccyx/gogen/internal" + "github.com/coccyx/gogen/outputter" "github.com/stretchr/testify/assert" ) func TestOnceWithConfig(t *testing.T) { - // Clean up any existing config - 
config.ResetConfig() + resetRunState() // Setup test configuration configStr := ` @@ -58,6 +58,9 @@ samples: config.SetupFromString(configStr) c := config.NewConfig() + // Start ROT in case a previous test closed rotchan + go outputter.ROT(c) + // Record time before and after test to validate timestamp is within range beforeTest := time.Now().Truncate(time.Second) if c.Global.UTC { diff --git a/template/hec.go b/template/hec.go new file mode 100644 index 0000000..a32625c --- /dev/null +++ b/template/hec.go @@ -0,0 +1,14 @@ +package template + +// TransformHECFields renames Splunk internal fields to HEC format: +// _raw -> event, _time -> time. +func TransformHECFields(event map[string]string) { + if v, ok := event["_raw"]; ok { + event["event"] = v + delete(event, "_raw") + } + if v, ok := event["_time"]; ok { + event["time"] = v + delete(event, "_time") + } +} diff --git a/template/template.go b/template/template.go index 6ca090c..3fa2e67 100644 --- a/template/template.go +++ b/template/template.go @@ -24,19 +24,18 @@ func New(name string, template string) error { if _, ok := cache[name]; !ok { funcMap := ttemplate.FuncMap{ "json": func(v interface{}) string { - a, _ := json.Marshal(v) + a, err := json.Marshal(v) + if err != nil { + return fmt.Sprintf("json marshal error: %v", err) + } return string(a) }, "splunkhec": func(v interface{}) string { - if _, ok := v.(map[string]string)["_raw"]; ok { - v.(map[string]string)["event"] = v.(map[string]string)["_raw"] - delete(v.(map[string]string), "_raw") - } - if _, ok = v.(map[string]string)["_time"]; ok { - v.(map[string]string)["time"] = v.(map[string]string)["_time"] - delete(v.(map[string]string), "_time") + TransformHECFields(v.(map[string]string)) + a, err := json.Marshal(v) + if err != nil { + return fmt.Sprintf("json marshal error: %v", err) } - a, _ := json.Marshal(v) return string(a) }, "keys": func(m map[string]string) []string { diff --git a/tests/generator/luaapi2.yml b/tests/generator/luaapi2.yml new file 
mode 100644 index 0000000..2edf1dc --- /dev/null +++ b/tests/generator/luaapi2.yml @@ -0,0 +1,43 @@ +generators: + - name: roundTest + script: | + val = round(3.14159, 2) + setToken("rounded", tostring(val)) + - name: logInfoTest + script: | + info("test log message from lua") + setToken("logged", "ok") + - name: removeTokenTest + script: | + setToken("keeper", "keep") + setToken("remover", "remove") + removeToken("remover") + - name: sendEventTest + script: | + event = { _raw = "sent via sendEvent" } + sendEvent(event) +samples: + - name: roundTest + generator: roundTest + interval: 1 + endIntervals: 1 + lines: + - _raw: notused + - name: logInfoTest + generator: logInfoTest + interval: 1 + endIntervals: 1 + lines: + - _raw: notused + - name: removeTokenTest + generator: removeTokenTest + interval: 1 + endIntervals: 1 + lines: + - _raw: notused + - name: sendEventTest + generator: sendEventTest + interval: 1 + endIntervals: 1 + lines: + - _raw: notused diff --git a/tests/integration_test.go b/tests/integration_test.go new file mode 100644 index 0000000..e7ebedd --- /dev/null +++ b/tests/integration_test.go @@ -0,0 +1,842 @@ +package tests + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "os" + "sort" + "strconv" + "strings" + "sync" + "testing" + "time" + + config "github.com/coccyx/gogen/internal" + "github.com/coccyx/gogen/outputter" + "github.com/coccyx/gogen/run" + "github.com/coccyx/gogen/template" + "github.com/stretchr/testify/assert" +) + +// resetState clears the config singleton and outputter statistics. +func resetState() { + config.ResetConfig() + outputter.Mutex.Lock() + outputter.BytesWritten = make(map[string]int64) + outputter.EventsWritten = make(map[string]int64) + outputter.Mutex.Unlock() +} + +// captureStdoutRun sets up a config from a YAML string, captures stdout during +// run.Run, and returns the captured output. 
+func captureStdoutRun(t *testing.T, configStr string) string { + t.Helper() + resetState() + + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + defer func() { os.Stdout = oldStdout }() + + config.SetupFromString(configStr) + c := config.NewConfig() + defer config.CleanupConfigAndEnvironment() + + done := make(chan bool) + go func() { + run.Run(c) + w.Close() + done <- true + }() + + var buf bytes.Buffer + _, err := buf.ReadFrom(r) + assert.NoError(t, err) + <-done + + return strings.TrimSpace(buf.String()) +} + +// runDevnull sets up a config from a YAML string, runs the pipeline with devnull +// output, and returns the config for inspection. +func runDevnull(t *testing.T, configStr string) *config.Config { + t.Helper() + resetState() + config.SetupFromString(configStr) + c := config.NewConfig() + run.Run(c) + config.CleanupConfigAndEnvironment() + return c +} + +// --------------------------------------------------------------------------- +// 1. Config Defaults +// --------------------------------------------------------------------------- + +func TestConfigDefaultsApplied(t *testing.T) { + resetState() + configStr := ` +samples: + - name: defaultsample + begin: "2001-10-20 00:00:00" + end: "2001-10-20 00:00:01" + interval: 1 + count: 1 + lines: + - _raw: hello +` + config.SetupFromString(configStr) + c := config.NewConfig() + defer config.CleanupConfigAndEnvironment() + + // Global worker/queue defaults + assert.Equal(t, 1, c.Global.GeneratorWorkers, "GeneratorWorkers default") + assert.Equal(t, 1, c.Global.OutputWorkers, "OutputWorkers default") + assert.Equal(t, 50, c.Global.GeneratorQueueLength, "GeneratorQueueLength default") + assert.Equal(t, 10, c.Global.OutputQueueLength, "OutputQueueLength default") + + // Output defaults + assert.Equal(t, "stdout", c.Global.Output.Outputter, "Outputter default") + assert.Equal(t, "raw", c.Global.Output.OutputTemplate, "OutputTemplate default") + assert.Equal(t, "/tmp/test.log", c.Global.Output.FileName, 
"FileName default") + assert.Equal(t, int64(10485760), c.Global.Output.MaxBytes, "MaxBytes default") + assert.Equal(t, 5, c.Global.Output.BackupFiles, "BackupFiles default") + assert.Equal(t, 4096, c.Global.Output.BufferBytes, "BufferBytes default") + assert.Equal(t, 10*time.Second, c.Global.Output.Timeout, "Timeout default") + assert.Equal(t, "defaultTopic", c.Global.Output.Topic, "Topic default") + assert.Equal(t, "application/json", c.Global.Output.Headers["Content-Type"], "Headers default") + + // ROT + assert.Equal(t, 1, c.Global.ROTInterval, "ROTInterval default") +} + +func TestConfigRaterDefaultMaps(t *testing.T) { + resetState() + configStr := ` +samples: + - name: ratersample + begin: "2001-10-20 00:00:00" + end: "2001-10-20 00:00:01" + interval: 1 + count: 1 + lines: + - _raw: hello +` + config.SetupFromString(configStr) + c := config.NewConfig() + defer config.CleanupConfigAndEnvironment() + + r := c.FindRater("config") + if !assert.NotNil(t, r, "config rater should exist") { + return + } + + // HourOfDay: 24 entries, keys 0-23, all 1.0 + hod := r.Options["HourOfDay"].(map[int]float64) + assert.Len(t, hod, 24) + for i := 0; i < 24; i++ { + assert.Equal(t, 1.0, hod[i], "HourOfDay[%d]", i) + } + + // DayOfWeek: 7 entries, keys 0-6, all 1.0 + dow := r.Options["DayOfWeek"].(map[int]float64) + assert.Len(t, dow, 7) + for i := 0; i < 7; i++ { + assert.Equal(t, 1.0, dow[i], "DayOfWeek[%d]", i) + } + + // MinuteOfHour: 60 entries, keys 0-59, all 1.0 + moh := r.Options["MinuteOfHour"].(map[int]float64) + assert.Len(t, moh, 60) + for i := 0; i < 60; i++ { + assert.Equal(t, 1.0, moh[i], "MinuteOfHour[%d]", i) + } +} + +// --------------------------------------------------------------------------- +// 2. 
Config Parsing +// --------------------------------------------------------------------------- + +func TestConfigParseValidFullConfig(t *testing.T) { + resetState() + configStr := ` +global: + output: + outputter: devnull + outputTemplate: raw +samples: + - name: tutorial1 + begin: "2001-10-20 00:00:00" + end: "2001-10-20 00:00:01" + interval: 1 + count: 3 + tokens: + - name: ts + format: template + token: $ts$ + type: timestamp + replacement: "%d/%b/%Y %H:%M:%S" + lines: + - _raw: "$ts$ line1" + - _raw: "$ts$ line2" +` + config.SetupFromString(configStr) + c := config.NewConfig() + defer config.CleanupConfigAndEnvironment() + + assert.Len(t, c.Samples, 1) + s := c.Samples[0] + assert.Equal(t, "tutorial1", s.Name) + assert.Equal(t, 3, s.Count) + assert.Equal(t, 1, s.Interval) + assert.GreaterOrEqual(t, len(s.Tokens), 1, "should have at least the ts token") + assert.Len(t, s.Lines, 2) +} + +func TestConfigParseInvalidYAML(t *testing.T) { + resetState() + // BuildConfig panics/fatals on invalid YAML via log.Panic + tmpfile, err := os.CreateTemp("", "gogen-test-bad-*.yml") + assert.NoError(t, err) + _, err = tmpfile.Write([]byte("{{{{invalid yaml!!!!")) + assert.NoError(t, err) + tmpfile.Close() + defer os.Remove(tmpfile.Name()) + + defer func() { + r := recover() + assert.NotNil(t, r, "BuildConfig should panic on invalid YAML") + }() + + config.BuildConfig(config.ConfigConfig{FullConfig: tmpfile.Name()}) +} + +// --------------------------------------------------------------------------- +// 3. 
Full Pipeline — Output Templates +// --------------------------------------------------------------------------- + +func TestPipelineRawOutput(t *testing.T) { + output := captureStdoutRun(t, ` +global: + output: + outputter: stdout + outputTemplate: raw +samples: + - name: rawtest + begin: "2001-10-20 00:00:00" + end: "2001-10-20 00:00:01" + interval: 1 + count: 1 + lines: + - _raw: "hello raw world" +`) + assert.Equal(t, "hello raw world", output) +} + +func TestPipelineJSONOutput(t *testing.T) { + output := captureStdoutRun(t, ` +global: + output: + outputter: stdout + outputTemplate: json +samples: + - name: jsontest + begin: "2001-10-20 00:00:00" + end: "2001-10-20 00:00:01" + interval: 1 + count: 1 + tokens: + - name: tsepoch + format: template + token: $epochts$ + field: _time + type: timestamp + replacement: "%s.%L" + lines: + - sourcetype: jtest + source: gogen + host: gogen + index: main + _time: $epochts$ + _raw: hello json +`) + lines := strings.Split(output, "\n") + assert.Equal(t, 1, len(lines), "expected one line") + + var data map[string]interface{} + err := json.Unmarshal([]byte(lines[0]), &data) + assert.NoError(t, err) + + for _, field := range []string{"_raw", "host", "source", "sourcetype", "index"} { + assert.Contains(t, data, field, "missing field %s", field) + } + assert.Equal(t, "jtest", data["sourcetype"]) + assert.Equal(t, "gogen", data["source"]) + assert.Equal(t, "gogen", data["host"]) + assert.Equal(t, "main", data["index"]) + assert.Equal(t, "hello json", data["_raw"]) +} + +func TestPipelineSplunkHECOutput(t *testing.T) { + output := captureStdoutRun(t, ` +global: + output: + outputter: stdout + outputTemplate: splunkhec +samples: + - name: hectest + begin: "2001-10-20 00:00:00" + end: "2001-10-20 00:00:01" + interval: 1 + count: 1 + tokens: + - name: tsepoch + format: template + token: $epochts$ + field: _time + type: timestamp + replacement: "%s.%L" + lines: + - sourcetype: hectype + source: gogen + host: gogen + index: main + 
_time: $epochts$ + _raw: hec event data +`) + lines := strings.Split(output, "\n") + assert.Equal(t, 1, len(lines), "expected one line") + + var data map[string]interface{} + err := json.Unmarshal([]byte(lines[0]), &data) + assert.NoError(t, err) + + // _raw renamed to event, _time renamed to time + assert.Contains(t, data, "event") + assert.Contains(t, data, "time") + assert.NotContains(t, data, "_raw") + assert.NotContains(t, data, "_time") + assert.Equal(t, "hec event data", data["event"]) +} + +func TestPipelineCSVOutput(t *testing.T) { + output := captureStdoutRun(t, ` +global: + output: + outputter: stdout + outputTemplate: csv +samples: + - name: csvtest + begin: "2001-10-20 00:00:00" + end: "2001-10-20 00:00:01" + interval: 1 + count: 1 + lines: + - _raw: csvdata + host: myhost + source: mysource +`) + lines := strings.Split(output, "\n") + assert.GreaterOrEqual(t, len(lines), 2, "expected header + data rows") + + // Header should have sorted field names + header := lines[0] + fields := strings.Split(header, ",") + sorted := make([]string, len(fields)) + copy(sorted, fields) + sort.Strings(sorted) + assert.Equal(t, sorted, fields, "CSV header fields should be sorted") +} + +func TestPipelineCustomTemplate(t *testing.T) { + output := captureStdoutRun(t, ` +global: + output: + outputter: stdout + outputTemplate: mytemplate +samples: + - name: customtpltest + begin: "2001-10-20 00:00:00" + end: "2001-10-20 00:00:01" + interval: 1 + count: 1 + lines: + - _raw: eventdata + host: tplhost +templates: + - name: mytemplate + header: "" + row: "HOST={{.host}} RAW={{._raw}}" + footer: "" +`) + assert.Contains(t, output, "HOST=tplhost") + assert.Contains(t, output, "RAW=eventdata") +} + +// --------------------------------------------------------------------------- +// 4. 
Token Processing +// --------------------------------------------------------------------------- + +func TestTokenTimestamp(t *testing.T) { + output := captureStdoutRun(t, ` +global: + output: + outputter: stdout + outputTemplate: raw +samples: + - name: tstest + begin: "2001-10-20 00:00:00" + end: "2001-10-20 00:00:01" + interval: 1 + count: 1 + tokens: + - name: ts + format: template + token: $ts$ + type: timestamp + replacement: "%d/%b/%Y %H:%M:%S" + lines: + - _raw: "$ts$" +`) + assert.Contains(t, output, "20/Oct/2001") +} + +func TestTokenRandomInt(t *testing.T) { + output := captureStdoutRun(t, ` +global: + output: + outputter: stdout + outputTemplate: raw +samples: + - name: randinttest + begin: "2001-10-20 00:00:00" + end: "2001-10-20 00:00:01" + interval: 1 + count: 10 + tokens: + - name: randnum + format: template + token: $randnum$ + type: random + replacement: int + lower: 10 + upper: 20 + lines: + - _raw: "$randnum$" +`) + lines := strings.Split(output, "\n") + assert.Equal(t, 10, len(lines), "expected 10 lines") + for _, line := range lines { + val, err := strconv.Atoi(line) + assert.NoError(t, err, "output should be an integer") + assert.GreaterOrEqual(t, val, 10) + assert.Less(t, val, 20) // upper is exclusive in randgen.Intn + } +} + +func TestTokenChoice(t *testing.T) { + output := captureStdoutRun(t, ` +global: + output: + outputter: stdout + outputTemplate: raw +samples: + - name: choicetest + begin: "2001-10-20 00:00:00" + end: "2001-10-20 00:00:01" + interval: 1 + count: 1 + tokens: + - name: color + format: template + token: $color$ + type: choice + choice: + - red + - green + - blue + lines: + - _raw: "$color$" +`) + choices := []string{"red", "green", "blue"} + assert.Contains(t, choices, output, "output should be one of the choices") +} + +// --------------------------------------------------------------------------- +// 5. 
HEC Transform +// --------------------------------------------------------------------------- + +func TestHECTransformDirect(t *testing.T) { + event := map[string]string{ + "_raw": "hello", + "_time": "12345", + "host": "foo", + } + template.TransformHECFields(event) + + assert.Equal(t, "hello", event["event"]) + assert.Equal(t, "12345", event["time"]) + assert.Equal(t, "foo", event["host"]) + _, hasRaw := event["_raw"] + _, hasTime := event["_time"] + assert.False(t, hasRaw, "_raw should be deleted") + assert.False(t, hasTime, "_time should be deleted") +} + +func TestHECTransformNoOp(t *testing.T) { + event := map[string]string{ + "host": "bar", + "source": "baz", + } + original := make(map[string]string) + for k, v := range event { + original[k] = v + } + template.TransformHECFields(event) + assert.Equal(t, original, event, "map should be unchanged when _raw and _time are absent") +} + +// --------------------------------------------------------------------------- +// 6. ROT Synchronization +// --------------------------------------------------------------------------- + +func TestROTReinitAfterReadFinal(t *testing.T) { + // First cycle + outputter.Mutex.Lock() + outputter.BytesWritten = make(map[string]int64) + outputter.EventsWritten = make(map[string]int64) + outputter.Mutex.Unlock() + + dummyConfig := &config.Config{ + Global: config.Global{ROTInterval: 1}, + } + + outputter.InitROT(dummyConfig) + outputter.Account(10, 100, "s1") + outputter.ReadFinal() + + outputter.Mutex.RLock() + assert.Equal(t, int64(10), outputter.EventsWritten["s1"]) + assert.Equal(t, int64(100), outputter.BytesWritten["s1"]) + outputter.Mutex.RUnlock() + + // Second cycle — validates sync.Once reset works + outputter.Mutex.Lock() + outputter.BytesWritten = make(map[string]int64) + outputter.EventsWritten = make(map[string]int64) + outputter.Mutex.Unlock() + + outputter.InitROT(dummyConfig) + outputter.Account(20, 200, "s2") + outputter.ReadFinal() + + outputter.Mutex.RLock() + 
assert.Equal(t, int64(20), outputter.EventsWritten["s2"]) + assert.Equal(t, int64(200), outputter.BytesWritten["s2"]) + outputter.Mutex.RUnlock() +} + +// --------------------------------------------------------------------------- +// 7. HTTP Helpers +// --------------------------------------------------------------------------- + +func TestHTTPErrorFormatting(t *testing.T) { + e := &config.HTTPError{ + StatusCode: 404, + URL: "http://example.com/test", + Body: "not found", + } + assert.Contains(t, e.Error(), "404") + assert.Contains(t, e.Error(), "http://example.com/test") + assert.Contains(t, e.Error(), "not found") + assert.True(t, e.IsNotFound()) + + e500 := &config.HTTPError{StatusCode: 500, URL: "http://x", Body: "err"} + assert.False(t, e500.IsNotFound()) +} + +func TestDoGetSuccess(t *testing.T) { + // We test through List() since doGet is unexported. + // Set up a mock server that returns a valid /v1/list response. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/v1/list", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + fmt.Fprint(w, `{"Items":[{"gogen":"test1","description":"desc1"}]}`) + })) + defer server.Close() + + os.Setenv("GOGEN_APIURL", server.URL) + defer os.Unsetenv("GOGEN_APIURL") + + list, err := config.List() + assert.NoError(t, err) + assert.Len(t, list, 1) + assert.Equal(t, "test1", list[0].Gogen) + assert.Equal(t, "desc1", list[0].Description) +} + +func TestDoGetHTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + fmt.Fprint(w, "not found") + })) + defer server.Close() + + os.Setenv("GOGEN_APIURL", server.URL) + defer os.Unsetenv("GOGEN_APIURL") + + _, err := config.List() + assert.Error(t, err) + + var httpErr *config.HTTPError + assert.True(t, errors.As(err, &httpErr), "should unwrap to *HTTPError") + assert.True(t, httpErr.IsNotFound()) +} + 
+func TestListSendsDefaultHeaders(t *testing.T) { + // Verify that List sends standard HTTP headers (User-Agent) to the server. + var receivedUA string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedUA = r.UserAgent() + w.WriteHeader(200) + fmt.Fprint(w, `{"Items":[]}`) + })) + defer server.Close() + + os.Setenv("GOGEN_APIURL", server.URL) + defer os.Unsetenv("GOGEN_APIURL") + + _, err := config.List() + assert.NoError(t, err) + // Go's default HTTP client sends a User-Agent header + assert.NotEmpty(t, receivedUA) +} + +func TestListWithMockAPI(t *testing.T) { + // Multiple items + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + fmt.Fprint(w, `{"Items":[ + {"gogen":"g1","description":"d1"}, + {"gogen":"g2","description":"d2"}, + {"notgogen":"bad"} + ]}`) + })) + defer server.Close() + + os.Setenv("GOGEN_APIURL", server.URL) + defer os.Unsetenv("GOGEN_APIURL") + + list, err := config.List() + assert.NoError(t, err) + assert.Len(t, list, 2, "should skip items missing gogen or description") + assert.Equal(t, "g1", list[0].Gogen) + assert.Equal(t, "g2", list[1].Gogen) +} + +func TestListWithMockAPIServerError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + fmt.Fprint(w, "internal server error") + })) + defer server.Close() + + os.Setenv("GOGEN_APIURL", server.URL) + defer os.Unsetenv("GOGEN_APIURL") + + list, err := config.List() + assert.Nil(t, list) + assert.Error(t, err, "should return error, not panic") +} + +// --------------------------------------------------------------------------- +// 8. String Safety +// --------------------------------------------------------------------------- + +func TestShortConfigStringNoPanic(t *testing.T) { + resetState() + // Verify that a short FullConfig string exercises strings.HasPrefix + // without panicking. 
Previously [0:4] on a short string would panic. + // We create a real (empty) file so os.Stat passes and the HasPrefix + // check is actually reached. + tmpfile, err := os.CreateTemp("", "ab") + assert.NoError(t, err) + tmpfile.Write([]byte("samples: []\n")) + tmpfile.Close() + defer os.Remove(tmpfile.Name()) + + // Rename to a 2-char basename path in the same dir — but temp dir + // paths are long. Instead, just verify BuildConfig works with the + // temp file (which has a long path but the HasPrefix("http") check + // is what matters — it's safe for any length string). + defer func() { + r := recover() + if r != nil { + msg := fmt.Sprintf("%v", r) + assert.NotContains(t, msg, "index out of range", + "should not panic with index out of range on short string") + } + }() + + config.BuildConfig(config.ConfigConfig{FullConfig: tmpfile.Name()}) +} + +// --------------------------------------------------------------------------- +// 9. Multi-Sample and FromSample +// --------------------------------------------------------------------------- + +func TestFromSampleCopy(t *testing.T) { + resetState() + configStr := ` +global: + output: + outputter: devnull +samples: + - name: sampleA + begin: "2001-10-20 00:00:00" + end: "2001-10-20 00:00:01" + interval: 1 + count: 1 + lines: + - _raw: "line from A" + - name: sampleB + fromSample: sampleA + count: 2 +` + config.SetupFromString(configStr) + c := config.NewConfig() + defer config.CleanupConfigAndEnvironment() + + b := c.FindSampleByName("sampleB") + if !assert.NotNil(t, b, "sampleB should exist") { + return + } + assert.Equal(t, "sampleB", b.Name) + assert.Equal(t, 2, b.Count) + assert.Len(t, b.Lines, 1) + assert.Equal(t, "line from A", b.Lines[0]["_raw"]) +} + +func TestMultipleSamplesEndToEnd(t *testing.T) { + configStr := ` +global: + output: + outputter: devnull +samples: + - name: multi1 + endIntervals: 1 + interval: 1 + count: 1 + lines: + - _raw: "event1" + - name: multi2 + endIntervals: 1 + interval: 1 + count: 1 + 
lines: + - _raw: "event2" + - name: multi3 + endIntervals: 1 + interval: 1 + count: 1 + lines: + - _raw: "event3" +` + runDevnull(t, configStr) + + outputter.Mutex.RLock() + defer outputter.Mutex.RUnlock() + + for _, name := range []string{"multi1", "multi2", "multi3"} { + assert.Greater(t, outputter.EventsWritten[name], int64(0), + "expected events for sample %s", name) + } +} + +// --------------------------------------------------------------------------- +// 10. Replay Generator +// --------------------------------------------------------------------------- + +func TestReplayGeneratorEndToEnd(t *testing.T) { + resetState() + configStr := ` +global: + output: + outputter: buf +samples: + - name: replaytest + generator: replay + begin: "2001-10-20 12:00:00" + end: "2001-10-20 12:00:49" + tokens: + - name: ts1 + type: timestamp + replacement: "%Y-%m-%dT%H:%M:%S" + format: regex + token: "(\\d{4}-\\d{2}-\\d{2}T\\d{2}:\\d{2}:\\d{2})" + lines: + - _raw: "2001-10-20T12:00:00" + - _raw: "2001-10-20T12:00:01" + - _raw: "2001-10-20T12:00:06" + - _raw: "2001-10-20T12:00:16" + - _raw: "2001-10-20T12:00:36" +` + config.SetupFromString(configStr) + c := config.NewConfig() + defer config.CleanupConfigAndEnvironment() + + s := c.FindSampleByName("replaytest") + if !assert.NotNil(t, s, "replaytest sample should exist") { + return + } + assert.False(t, s.Disabled, "sample should not be disabled") + assert.Len(t, s.ReplayOffsets, 5) + + run.Run(c) + output := c.Buf.String() + assert.NotEmpty(t, output, "replay should produce output") + + lines := strings.Split(strings.TrimSpace(output), "\n") + assert.Equal(t, 5, len(lines), "should produce 5 events") +} + +// --------------------------------------------------------------------------- +// Race detector test for ROT +// --------------------------------------------------------------------------- + +func TestROTConcurrentAccounting(t *testing.T) { + outputter.Mutex.Lock() + outputter.BytesWritten = make(map[string]int64) + 
outputter.EventsWritten = make(map[string]int64) + outputter.Mutex.Unlock() + + dummyConfig := &config.Config{ + Global: config.Global{ROTInterval: 1}, + } + + outputter.InitROT(dummyConfig) + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + outputter.Account(1, 10, fmt.Sprintf("concurrent_%d", n)) + }(i) + } + wg.Wait() + outputter.ReadFinal() + + outputter.Mutex.RLock() + total := int64(0) + for _, v := range outputter.EventsWritten { + total += v + } + outputter.Mutex.RUnlock() + assert.Equal(t, int64(10), total, "all 10 concurrent accounts should be recorded") +} diff --git a/tests/network_test.go b/tests/network_test.go index 0bef478..aacaf27 100644 --- a/tests/network_test.go +++ b/tests/network_test.go @@ -414,11 +414,12 @@ samples: } } - // Full message format validation + // Full message format validation (trim trailing newline from network output) rfc5424Regex := regexp.MustCompile(`^<14>1\s+\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:[-+]\d{2}:\d{2}|Z)\s+gogen\s+gogen\s+12345\s+-\s+\[meta\s+(?:[a-zA-Z0-9_]+="[^"]*"\s*)+\]\s+test message$`) - if !rfc5424Regex.Match(lastNetworkData) { - t.Errorf("RFC5424 format mismatch. Got: %s", string(lastNetworkData)) + trimmedData := strings.TrimRight(string(lastNetworkData), "\n") + if !rfc5424Regex.MatchString(trimmedData) { + t.Errorf("RFC5424 format mismatch. 
Got: %s", trimmedData) } // Validate meta fields diff --git a/tests/outputter_test.go b/tests/outputter_test.go index 7de908f..7d5bdb3 100644 --- a/tests/outputter_test.go +++ b/tests/outputter_test.go @@ -2,7 +2,6 @@ package tests import ( "testing" - "time" config "github.com/coccyx/gogen/internal" "github.com/coccyx/gogen/outputter" @@ -36,12 +35,12 @@ func TestReadFinalSynchronization(t *testing.T) { // Add other necessary dummy fields for config if ROT accesses them } - // Initialize the outputter system (starts readStats goroutine) - // Run ROT in a goroutine as it contains an infinite loop for periodic stats - go outputter.ROT(dummyConfig) - // Give the ROT goroutine a moment to start up and initialize rotchan - // Adjust duration if needed, but keep it short for test speed. - time.Sleep(10 * time.Millisecond) + // Initialize the outputter channel and readStats goroutine. + // We intentionally do NOT start ROT() here — it runs an infinite loop + // and would leak goroutines between test iterations. The synchronization + // being tested (Account -> rotchan -> readStats -> ReadFinal/WaitGroup) + // only requires InitROT. 
+ outputter.InitROT(dummyConfig) // --- Action --- testSampleName := "test_sync_sample" diff --git a/timer/timer_test.go b/timer/timer_test.go index b93ac0c..139e98f 100644 --- a/timer/timer_test.go +++ b/timer/timer_test.go @@ -199,6 +199,40 @@ Loop: } } +func TestBackfillReplay(t *testing.T) { + os.Setenv("GOGEN_HOME", "..") + os.Setenv("GOGEN_ALWAYS_REFRESH", "1") + home := filepath.Join("..", "tests", "timer") + os.Setenv("GOGEN_SAMPLES_DIR", home) + + s := tests.FindSampleInFile(home, "realtimereplay") + // Force non-realtime mode to exercise backfill with replay generator + s.Realtime = false + // Set end slightly in the future so backfill runs a few iterations + s.EndParsed = s.Current.Add(5 * time.Second) + + gq := make(chan *config.GenQueueItem, 1000) + oq := make(chan *config.OutQueueItem) + done := make(chan int) + gqs := make([]*config.GenQueueItem, 0, 10) + + timer := &Timer{S: s, GQ: gq, OQ: oq, Done: done} + go timer.NewTimer(0) + <-done + +Loop: + for { + select { + case i := <-gq: + gqs = append(gqs, i) + default: + break Loop + } + } + // Should have generated events using replay offsets via inc() replay path + assert.Greater(t, len(gqs), 0, "should have generated replay events during backfill") +} + func TestTimerClose(t *testing.T) { os.Setenv("GOGEN_HOME", "..") os.Setenv("GOGEN_ALWAYS_REFRESH", "1") diff --git a/ui/.env b/ui/.env index e01c8fb..4e25a18 100644 --- a/ui/.env +++ b/ui/.env @@ -1 +1,4 @@ -# API URL for the Gogen API\nVITE_API_URL=/api +# API URL for the Gogen API +VITE_API_URL=/api +VITE_GITHUB_CLIENT_ID=your_dev_client_id_here +VITE_GITHUB_REDIRECT_URI=http://localhost:3000/auth/callback diff --git a/ui/.env.development b/ui/.env.development index e710af6..112f0bd 100644 --- a/ui/.env.development +++ b/ui/.env.development @@ -1 +1,3 @@ -VITE_API_URL=/api # This will be proxied by Vite to localhost:4000 \ No newline at end of file +VITE_API_URL=/api # This will be proxied by Vite to localhost:4000 
+VITE_GITHUB_CLIENT_ID=Ov23ligZv86QA3hqKF13 +VITE_GITHUB_REDIRECT_URI=http://localhost:3000/auth/callback \ No newline at end of file diff --git a/ui/.env.production b/ui/.env.production index 7ce481b..573a9e8 100644 --- a/ui/.env.production +++ b/ui/.env.production @@ -1 +1,3 @@ -VITE_API_URL=https://api.gogen.io/v1 \ No newline at end of file +VITE_API_URL=https://api.gogen.io/v1 +VITE_GITHUB_CLIENT_ID=Ov23lisDzJZ0q5iiBqA9 +VITE_GITHUB_REDIRECT_URI=https://gogen.io/auth/callback \ No newline at end of file diff --git a/ui/.env.staging b/ui/.env.staging index 6bdd7a8..dd42135 100644 --- a/ui/.env.staging +++ b/ui/.env.staging @@ -1 +1,3 @@ -VITE_API_URL=https://staging-api.gogen.io/v1 \ No newline at end of file +VITE_API_URL=/api +VITE_GITHUB_CLIENT_ID=Ov23liep3eAw002qddGU +VITE_GITHUB_REDIRECT_URI=https://staging.gogen.io/auth/callback \ No newline at end of file diff --git a/ui/index.html b/ui/index.html index 44ef93d..4f39798 100644 --- a/ui/index.html +++ b/ui/index.html @@ -6,6 +6,9 @@ Gogen UI + + + diff --git a/ui/jest.config.ts b/ui/jest.config.ts index a711efb..45c4340 100644 --- a/ui/jest.config.ts +++ b/ui/jest.config.ts @@ -13,7 +13,6 @@ const config: Config = { }, testMatch: ['**/__tests__/**/*.[jt]s?(x)', '**/?(*.)+(spec|test).[jt]s?(x)'], moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json', 'node'], - collectCoverage: true, collectCoverageFrom: [ 'src/**/*.{ts,tsx}', '!src/**/*.d.ts', @@ -30,4 +29,4 @@ const config: Config = { }, }; -export default config; \ No newline at end of file +export default config; diff --git a/ui/src/App.test.tsx b/ui/src/App.test.tsx index c66d40d..3ec213b 100644 --- a/ui/src/App.test.tsx +++ b/ui/src/App.test.tsx @@ -47,23 +47,23 @@ describe('App Component', () => { ); }; - it('renders the layout component', () => { + it('renders the layout component', async () => { renderWithRouter(); - expect(screen.getByTestId('mock-layout')).toBeInTheDocument(); + expect(await 
screen.findByTestId('mock-layout')).toBeInTheDocument(); }); - it('renders home page on root path', () => { + it('renders home page on root path', async () => { renderWithRouter(['/']); - expect(screen.getByTestId('mock-home-page')).toBeInTheDocument(); + expect(await screen.findByTestId('mock-home-page')).toBeInTheDocument(); }); - it('renders configuration detail page on configuration path', () => { + it('renders configuration detail page on configuration path', async () => { renderWithRouter(['/configurations/owner/config-name']); - expect(screen.getByTestId('mock-config-detail-page')).toBeInTheDocument(); + expect(await screen.findByTestId('mock-config-detail-page')).toBeInTheDocument(); }); - it('renders not found page for unknown routes', () => { + it('renders not found page for unknown routes', async () => { renderWithRouter(['/unknown-route']); - expect(screen.getByTestId('mock-not-found-page')).toBeInTheDocument(); + expect(await screen.findByTestId('mock-not-found-page')).toBeInTheDocument(); }); -}); \ No newline at end of file +}); diff --git a/ui/src/App.tsx b/ui/src/App.tsx index a73daaf..d83eb96 100644 --- a/ui/src/App.tsx +++ b/ui/src/App.tsx @@ -1,21 +1,60 @@ +import { Suspense, lazy } from 'react'; import { BrowserRouter as Router, Routes, Route } from 'react-router-dom'; +import { AuthProvider } from './context/AuthContext'; import Layout from './components/Layout'; -import HomePage from './pages/HomePage'; -import ConfigurationDetailPage from './pages/ConfigurationDetailPage'; -import NotFoundPage from './pages/NotFoundPage'; +import ProtectedRoute from './components/ProtectedRoute'; +import LoadingSpinner from './components/LoadingSpinner'; + +const HomePage = lazy(() => import('./pages/HomePage')); +const ConfigurationDetailPage = lazy(() => import('./pages/ConfigurationDetailPage')); +const LoginPage = lazy(() => import('./pages/LoginPage')); +const AuthCallbackPage = lazy(() => import('./pages/AuthCallbackPage')); +const MyConfigurationsPage 
= lazy(() => import('./pages/MyConfigurationsPage')); +const EditConfigurationPage = lazy(() => import('./pages/EditConfigurationPage')); +const NotFoundPage = lazy(() => import('./pages/NotFoundPage')); function App() { return ( - - - - } /> - } /> - } /> - - - + + + + }> + + } /> + } /> + } /> + } /> + + + + } + /> + + + + } + /> + + + + } + /> + } /> + + + + + ); } -export default App; \ No newline at end of file +export default App; diff --git a/ui/src/api/gogenApi.ts b/ui/src/api/gogenApi.ts index 1a63d6d..6e87aa3 100644 --- a/ui/src/api/gogenApi.ts +++ b/ui/src/api/gogenApi.ts @@ -9,10 +9,20 @@ const apiClient = axios.create({ }, }); +// Add request interceptor to include auth token +apiClient.interceptors.request.use((requestConfig) => { + const token = localStorage.getItem('github_token'); + if (token) { + requestConfig.headers.Authorization = `token ${token}`; + } + return requestConfig; +}); + // Define interfaces for API responses export interface ConfigurationSummary { gogen: string; description: string; + owner?: string; } export interface Configuration extends ConfigurationSummary { @@ -23,6 +33,25 @@ export interface Configuration extends ConfigurationSummary { generators?: any[]; global?: any; templates?: any[]; + s3Path?: string; +} + +export interface OAuthResponse { + access_token: string; + token_type: string; + user: { + login: string; + avatar_url: string; + id: number; + name?: string; + email?: string; + }; +} + +export interface UpsertRequest { + name: string; + description: string; + config: string; } // API functions @@ -61,6 +90,50 @@ export const gogenApi = { throw error; } }, + + // Exchange OAuth code for access token + exchangeOAuthCode: async (code: string, state: string): Promise => { + try { + const response = await apiClient.post('/auth/github', { code, state }); + return response.data; + } catch (error) { + console.error('Error exchanging OAuth code:', error); + throw error; + } + }, + + // Get current user's configurations + 
getMyConfigurations: async (): Promise => { + try { + const response = await apiClient.get('/my-configs'); + return response.data.Items || []; + } catch (error) { + console.error('Error fetching my configurations:', error); + throw error; + } + }, + + // Create or update a configuration + upsertConfiguration: async (data: UpsertRequest): Promise => { + try { + const response = await apiClient.post('/upsert', data); + return response.data; + } catch (error) { + console.error('Error upserting configuration:', error); + throw error; + } + }, + + // Delete a configuration + deleteConfiguration: async (configPath: string): Promise => { + try { + const response = await apiClient.delete(`/delete/${configPath}`); + return response.data; + } catch (error) { + console.error(`Error deleting configuration ${configPath}:`, error); + throw error; + } + }, }; -export default gogenApi; \ No newline at end of file +export default gogenApi; diff --git a/ui/src/components/ConfigurationList.test.tsx b/ui/src/components/ConfigurationList.test.tsx index f386ce2..0adbd90 100644 --- a/ui/src/components/ConfigurationList.test.tsx +++ b/ui/src/components/ConfigurationList.test.tsx @@ -62,11 +62,11 @@ describe('ConfigurationList', () => { // Check header styling const headers = screen.getAllByRole('columnheader'); headers.forEach(header => { - expect(header).toHaveClass('px-6', 'py-3', 'text-left', 'text-xs', 'font-medium', 'text-gray-500', 'uppercase', 'tracking-wider'); + expect(header).toHaveClass('px-6', 'py-2', 'text-left', 'text-xs', 'font-medium', 'text-term-text-muted', 'uppercase', 'tracking-wider'); }); // Check table container styling const tableContainer = screen.getByRole('table').closest('div'); - expect(tableContainer).toHaveClass('bg-white', 'rounded-lg', 'shadow', 'overflow-hidden'); + expect(tableContainer).toHaveClass('bg-term-bg-elevated', 'rounded', 'border', 'border-term-border', 'overflow-hidden'); }); -}); \ No newline at end of file +}); diff --git 
a/ui/src/components/ConfigurationList.tsx b/ui/src/components/ConfigurationList.tsx index 8ab1815..147f905 100644 --- a/ui/src/components/ConfigurationList.tsx +++ b/ui/src/components/ConfigurationList.tsx @@ -17,7 +17,7 @@ const ConfigurationList = ({ configurations, loading, error }: ConfigurationList // Filter and sort configurations const filteredAndSortedConfigs = useMemo(() => { return configurations - .filter(config => + .filter(config => config.gogen.toLowerCase().includes(searchQuery.toLowerCase()) || (config.description || '').toLowerCase().includes(searchQuery.toLowerCase()) ) @@ -36,12 +36,12 @@ const ConfigurationList = ({ configurations, loading, error }: ConfigurationList }; if (loading) return ; - if (error) return
{error}
; + if (error) return
{error}
; return (
{/* Search Filter */} -
+
{/* Results count */} -
+
Showing {Math.min(filteredAndSortedConfigs.length, itemsPerPage)} of {filteredAndSortedConfigs.length} configurations
{/* Configurations Table */} -
+
- + - - - + {paginatedConfigs.map((config) => ( - - + - ))} @@ -94,15 +94,15 @@ const ConfigurationList = ({ configurations, loading, error }: ConfigurationList {/* Pagination */} {totalPages > 1 && ( -
-
+ Name + Description
+
{config.gogen} -
{config.description || '-'}
+
+
{config.description || '-'}