Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 109 additions & 1 deletion cmd/engram/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"strings"
"syscall"

"github.com/Gentleman-Programming/engram/internal/embedding"
"github.com/Gentleman-Programming/engram/internal/mcp"
"github.com/Gentleman-Programming/engram/internal/project"
"github.com/Gentleman-Programming/engram/internal/server"
Expand Down Expand Up @@ -162,6 +163,8 @@ func main() {
cmdProjects(cfg)
case "setup":
cmdSetup()
case "backfill-embeddings":
cmdBackfillEmbeddings(cfg)
case "version", "--version", "-v":
fmt.Printf("engram %s\n", version)
case "help", "--help", "-h":
Expand Down Expand Up @@ -195,6 +198,8 @@ func cmdServe(cfg store.Config) {
}
defer s.Close()

configureEmbeddings(s, "", "", "")

srv := newHTTPServer(s, port)

// Graceful shutdown on SIGINT/SIGTERM.
Expand All @@ -212,9 +217,12 @@ func cmdServe(cfg store.Config) {
}

func cmdMCP(cfg store.Config) {
// Parse --tools and --project flags
// Parse --tools, --project, and --embedding-* flags
toolsFilter := ""
projectOverride := ""
embProvider := ""
embModel := ""
embURL := ""
for i := 2; i < len(os.Args); i++ {
if strings.HasPrefix(os.Args[i], "--tools=") {
toolsFilter = strings.TrimPrefix(os.Args[i], "--tools=")
Expand All @@ -226,6 +234,21 @@ func cmdMCP(cfg store.Config) {
} else if os.Args[i] == "--project" && i+1 < len(os.Args) {
projectOverride = os.Args[i+1]
i++
} else if strings.HasPrefix(os.Args[i], "--embedding-provider=") {
embProvider = strings.TrimPrefix(os.Args[i], "--embedding-provider=")
} else if os.Args[i] == "--embedding-provider" && i+1 < len(os.Args) {
embProvider = os.Args[i+1]
i++
} else if strings.HasPrefix(os.Args[i], "--embedding-model=") {
embModel = strings.TrimPrefix(os.Args[i], "--embedding-model=")
} else if os.Args[i] == "--embedding-model" && i+1 < len(os.Args) {
embModel = os.Args[i+1]
i++
} else if strings.HasPrefix(os.Args[i], "--embedding-url=") {
embURL = strings.TrimPrefix(os.Args[i], "--embedding-url=")
} else if os.Args[i] == "--embedding-url" && i+1 < len(os.Args) {
embURL = os.Args[i+1]
i++
}
}

Expand All @@ -248,6 +271,8 @@ func cmdMCP(cfg store.Config) {
}
defer s.Close()

configureEmbeddings(s, embProvider, embModel, embURL)

mcpCfg := mcp.MCPConfig{
DefaultProject: detectedProject,
}
Expand All @@ -260,6 +285,42 @@ func cmdMCP(cfg store.Config) {
}
}

// configureEmbeddings wires an embedding provider into the store, if one is
// configured. Values passed as arguments (from CLI flags) take precedence
// over the ENGRAM_EMBEDDING_* environment variables. An empty or "none"
// provider leaves the store without embeddings; a provider setup failure is
// logged and otherwise ignored so the command can still run without
// semantic search.
func configureEmbeddings(s *store.Store, provider, model, url string) {
	// CLI value wins; otherwise fall back to the environment.
	fallback := func(val, envKey string) string {
		if val != "" {
			return val
		}
		return os.Getenv(envKey)
	}
	provider = fallback(provider, "ENGRAM_EMBEDDING_PROVIDER")
	model = fallback(model, "ENGRAM_EMBEDDING_MODEL")
	url = fallback(url, "ENGRAM_EMBEDDING_URL")

	if provider == "" || provider == "none" {
		return
	}

	emb, err := embedding.NewProvider(embedding.Config{
		Provider: provider,
		Model:    model,
		URL:      url,
		APIKey:   os.Getenv("ENGRAM_EMBEDDING_API_KEY"),
	})
	if err != nil {
		log.Printf("[engram] embedding provider setup failed: %v", err)
		return
	}
	if emb == nil {
		// NewProvider may return (nil, nil); treat as "no provider".
		return
	}
	s.SetEmbeddingProvider(emb)
	log.Printf("[engram] embedding provider: %s (model: %s)", provider, emb.ModelName())
}

func cmdTUI(cfg store.Config) {
s, err := storeNew(cfg)
if err != nil {
Expand Down Expand Up @@ -726,6 +787,53 @@ func cmdSync(cfg store.Config) {
fmt.Printf(" git add .engram/ && git commit -m \"sync engram memories\"\n")
}

// cmdBackfillEmbeddings generates embeddings for stored observations that do
// not yet have one. Flags:
//
//	--batch-size=N             observations embedded per batch (default 50)
//	--embedding-provider=...   provider name (or ENGRAM_EMBEDDING_PROVIDER)
//	--embedding-model=...      model name (or ENGRAM_EMBEDDING_MODEL)
//	--embedding-url=...        provider URL (or ENGRAM_EMBEDDING_URL)
//
// Both --flag=value and --flag value forms are accepted, matching the flag
// handling in cmdMCP.
func cmdBackfillEmbeddings(cfg store.Config) {
	batchSize := 50
	embProvider := ""
	embModel := ""
	embURL := ""

	for i := 2; i < len(os.Args); i++ {
		arg := os.Args[i]
		switch {
		case strings.HasPrefix(arg, "--batch-size="):
			// Ignore non-numeric, zero, or negative sizes; keep the default.
			if n, err := strconv.Atoi(strings.TrimPrefix(arg, "--batch-size=")); err == nil && n > 0 {
				batchSize = n
			}
		case arg == "--batch-size" && i+1 < len(os.Args):
			if n, err := strconv.Atoi(os.Args[i+1]); err == nil && n > 0 {
				batchSize = n
			}
			i++
		case strings.HasPrefix(arg, "--embedding-provider="):
			embProvider = strings.TrimPrefix(arg, "--embedding-provider=")
		case arg == "--embedding-provider" && i+1 < len(os.Args):
			embProvider = os.Args[i+1]
			i++
		case strings.HasPrefix(arg, "--embedding-model="):
			embModel = strings.TrimPrefix(arg, "--embedding-model=")
		case arg == "--embedding-model" && i+1 < len(os.Args):
			embModel = os.Args[i+1]
			i++
		case strings.HasPrefix(arg, "--embedding-url="):
			embURL = strings.TrimPrefix(arg, "--embedding-url=")
		case arg == "--embedding-url" && i+1 < len(os.Args):
			embURL = os.Args[i+1]
			i++
		}
	}

	s, err := storeNew(cfg)
	if err != nil {
		fatal(err)
	}
	defer s.Close()

	configureEmbeddings(s, embProvider, embModel, embURL)

	if s.EmbeddingProvider() == nil {
		fmt.Fprintln(os.Stderr, "error: no embedding provider configured")
		fmt.Fprintln(os.Stderr, " set --embedding-provider=ollama or ENGRAM_EMBEDDING_PROVIDER=ollama")
		exitFunc(1)
		return
	}

	// EmbeddingProvider().ModelName() reports the model, not the provider —
	// label it accordingly (the original message said "provider").
	fmt.Fprintf(os.Stderr, "Backfilling embeddings (batch size: %d, model: %s)...\n", batchSize, s.EmbeddingProvider().ModelName())

	if err := s.BackfillEmbeddings(batchSize, func(done, total int) {
		// \r rewrites the same terminal line for in-place progress.
		fmt.Fprintf(os.Stderr, "\r %d / %d observations embedded", done, total)
	}); err != nil {
		fmt.Fprintln(os.Stderr) // terminate the progress line before the error
		fatal(err)
	}

	fmt.Fprintln(os.Stderr, "\nDone.")
}

func cmdProjects(cfg store.Config) {
// Route: engram projects list | engram projects consolidate [--all] [--dry-run]
subCmd := "list"
Expand Down
119 changes: 119 additions & 0 deletions internal/embedding/ollama.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package embedding

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"
)

// OllamaProvider generates embeddings via the Ollama REST API.
// It is configured with a base URL and model name; the vector
// dimensionality is learned from the first successful Embed call.
type OllamaProvider struct {
	url    string // base URL of the Ollama server, without the /api/embeddings path
	model  string // model name sent in every request
	dims   int    // cached vector length; 0 until the first successful Embed
	client *http.Client
}

// ollamaRequest is the JSON body sent to POST /api/embeddings.
type ollamaRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
}

// ollamaResponse is the JSON body returned by /api/embeddings.
type ollamaResponse struct {
	Embedding []float64 `json:"embedding"`
}

// NewOllamaProvider creates a provider that calls Ollama's /api/embeddings
// endpoint at url using the given model.
// The dimensions are probed on first call and cached.
func NewOllamaProvider(url, model string) (*OllamaProvider, error) {
	return &OllamaProvider{
		url:   url,
		model: model,
		// A client-level timeout bounds the whole request even when the
		// caller's context carries no deadline (the original zero-value
		// client could hang forever on a stalled server). Embedding large
		// texts on a cold local model can be slow, so keep it generous.
		client: &http.Client{Timeout: 2 * time.Minute},
	}, nil
}

// Embed requests an embedding for text from the Ollama server and returns
// it as a float32 vector. The vector length is cached in p.dims after the
// first successful call.
//
// NOTE(review): p.dims is written without synchronization — this assumes
// Embed is not called concurrently; confirm with callers.
func (p *OllamaProvider) Embed(ctx context.Context, text string) ([]float32, error) {
	payload := ollamaRequest{Model: p.model, Prompt: text}
	body, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("ollama: marshal request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url+"/api/embeddings", bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("ollama: create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := p.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("ollama: request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Include the server's error text in the returned error.
		msg, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("ollama: HTTP %d: %s", resp.StatusCode, string(msg))
	}

	var decoded ollamaResponse
	if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil {
		return nil, fmt.Errorf("ollama: decode response: %w", err)
	}
	if len(decoded.Embedding) == 0 {
		return nil, fmt.Errorf("ollama: empty embedding returned")
	}

	// The API returns float64; the store works in float32.
	vec := make([]float32, 0, len(decoded.Embedding))
	for _, v := range decoded.Embedding {
		vec = append(vec, float32(v))
	}

	// Remember the dimensionality from the first successful response.
	if p.dims == 0 {
		p.dims = len(vec)
	}
	return vec, nil
}

// Dimensions reports the embedding vector length observed from the first
// successful Embed call; it is 0 before any embedding has been produced.
func (p *OllamaProvider) Dimensions() int {
	return p.dims
}

// ModelName reports the Ollama model this provider embeds with.
func (p *OllamaProvider) ModelName() string {
	return p.model
}

// MaxChars returns a conservative character limit based on the model's token context.
// Ollama models vary widely: nomic-embed-text=8192 tokens, mxbai-embed-large=512 tokens.
func (p *OllamaProvider) MaxChars() int {
	return ollamaModelMaxChars(p.model)
}

// ollamaModelMaxChars returns the max character limit for known Ollama embedding models.
// Token-to-char ratios vary wildly: English prose ~4 chars/token, but markdown with
// code blocks, pipes, and special characters can be ~1.5 chars/token. We use empirically
// tested limits that work with real-world mixed content.
//
// A switch avoids rebuilding a map on every call; this runs once per
// embedded observation, so the constant-key lookup should be allocation-free.
func ollamaModelMaxChars(model string) int {
	switch model {
	case "nomic-embed-text":
		return 6000 // 8192 tokens, tested with markdown/code
	case "mxbai-embed-large":
		return 500 // 512 tokens, very limited
	case "all-minilm":
		return 250 // 256 tokens
	case "snowflake-arctic-embed":
		return 500 // 512 tokens
	default:
		// Unknown model — conservative default.
		return 6000
	}
}
108 changes: 108 additions & 0 deletions internal/embedding/openai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package embedding

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"
)

// OpenAIProvider generates embeddings via the OpenAI API.
// The vector dimensionality is learned from the first successful Embed call.
type OpenAIProvider struct {
	apiKey string // bearer token sent in the Authorization header
	model  string // model name sent in every request
	dims   int    // cached vector length; 0 until the first successful Embed
	client *http.Client
}

// openAIRequest is the JSON body sent to POST /v1/embeddings.
type openAIRequest struct {
	Model string `json:"model"`
	Input string `json:"input"`
}

// openAIResponse mirrors the embeddings API response; Error is populated
// when the API reports a failure in the response body.
type openAIResponse struct {
	Data []struct {
		Embedding []float64 `json:"embedding"`
	} `json:"data"`
	Error *struct {
		Message string `json:"message"`
	} `json:"error,omitempty"`
}

// NewOpenAIProvider creates a provider that calls the OpenAI embeddings API
// with the given API key and model.
func NewOpenAIProvider(apiKey, model string) (*OpenAIProvider, error) {
	return &OpenAIProvider{
		apiKey: apiKey,
		model:  model,
		// A client-level timeout bounds the whole request even when the
		// caller's context carries no deadline (the original zero-value
		// client could hang forever on a stalled connection).
		client: &http.Client{Timeout: 60 * time.Second},
	}, nil
}

// Embed requests an embedding for text from the OpenAI embeddings API and
// returns it as a float32 vector. The vector length is cached in p.dims
// after the first successful call.
//
// NOTE(review): p.dims is written without synchronization — this assumes
// Embed is not called concurrently; confirm with callers.
func (p *OpenAIProvider) Embed(ctx context.Context, text string) ([]float32, error) {
	payload := openAIRequest{Model: p.model, Input: text}
	body, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("openai: marshal request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://api.openai.com/v1/embeddings", bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("openai: create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+p.apiKey)

	resp, err := p.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("openai: request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Include the server's error text in the returned error.
		msg, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("openai: HTTP %d: %s", resp.StatusCode, string(msg))
	}

	var decoded openAIResponse
	if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil {
		return nil, fmt.Errorf("openai: decode response: %w", err)
	}
	// Some failures are reported in the body even with a 200 status.
	if decoded.Error != nil {
		return nil, fmt.Errorf("openai: API error: %s", decoded.Error.Message)
	}
	if len(decoded.Data) == 0 || len(decoded.Data[0].Embedding) == 0 {
		return nil, fmt.Errorf("openai: empty embedding returned")
	}

	// The API returns float64; the store works in float32.
	raw := decoded.Data[0].Embedding
	vec := make([]float32, 0, len(raw))
	for _, v := range raw {
		vec = append(vec, float32(v))
	}

	// Remember the dimensionality from the first successful response.
	if p.dims == 0 {
		p.dims = len(vec)
	}
	return vec, nil
}

// Dimensions reports the embedding vector length observed from the first
// successful Embed call; it is 0 before any embedding has been produced.
func (p *OpenAIProvider) Dimensions() int {
	return p.dims
}

// ModelName reports the OpenAI model this provider embeds with.
func (p *OpenAIProvider) ModelName() string {
	return p.model
}

// MaxChars returns a conservative character limit for OpenAI embedding models.
// All current OpenAI embedding models support 8,191 tokens; ~2 chars per
// token is a safe floor for mixed code/prose.
func (p *OpenAIProvider) MaxChars() int {
	const maxTokens = 8191
	return maxTokens * 2
}
Loading