diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..f0abcd3 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,4 @@ +.git +ollama_storage +bin/ +*.exe \ No newline at end of file diff --git a/.env.example b/.env.example index 96e8af6..01773c1 100644 --- a/.env.example +++ b/.env.example @@ -3,4 +3,15 @@ EMBEDDING_DIM=768 SERVER_ADDR=:8080 OLLAMA_ADDR=11434 MAX_MEMORY_DISTANCE=0.5 -TOP_K_MEMORIES=10 \ No newline at end of file +TOP_K_MEMORIES=10 + +SUMMARIZATION_MODEL=deepseek-r1:1.5b +LLM_MODEL=deepseek-r1:1.5b +WEAVIATE_HOST=weaviate:8080 +WEAVIATE_SCHEME=http + +# Summarization settings +SUMMARY_THRESHOLD=2 # Trigger summarization after N memories +SUMMARY_BATCH_SIZE=2 # Number of memories to summarize at once +SUMMARY_MAX_AGE_DAYS=3000 # Only summarize memories older than this +ENABLE_AUTO_SUMMARY=true # Enable automatic summarization diff --git a/.gitignore b/.gitignore index 4c842e4..a95e81c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .env -.DS_Store \ No newline at end of file +.DS_Store +ollama_storage/ diff --git a/Dockerfile.ollama b/Dockerfile.ollama index 062a383..1f0befd 100644 --- a/Dockerfile.ollama +++ b/Dockerfile.ollama @@ -1,5 +1,10 @@ -# Base Ollama image FROM ollama/ollama:latest -# Start the Ollama server by default -ENTRYPOINT ["ollama", "serve"] +# We'll use the default path during build +RUN ollama serve & \ + sleep 5 && \ + ollama pull nomic-embed-text && \ + ollama pull deepseek-r1:1.5b && \ + pkill ollama + +ENTRYPOINT ["ollama", "serve"] \ No newline at end of file diff --git a/cmd/server/main.go b/cmd/server/main.go index 196034d..9bb8be0 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -1,12 +1,17 @@ package main import ( + "encoding/json" + "fmt" "log" "net/http" "os" - "strconv" + "time" "github.com/joho/godotenv" + "github.com/weaviate/weaviate-go-client/v4/weaviate" + + wv "github.com/sobowalebukola/memcortex/internal/db/weaviate" ollama "github.com/sobowalebukola/memcortex/internal/embedder" "github.com/sobowalebukola/memcortex/internal/handlers" "github.com/sobowalebukola/memcortex/internal/memory" @@ -14,33 +19,117 @@ import ( ) func main() { + err := godotenv.Load() if err != nil { log.Println("No .env file found, using system environment") } - dimStr := os.Getenv("EMBEDDING_DIM") - if dimStr == "" { - dimStr = "768" + + cfg := weaviate.Config{ + Host: os.Getenv("WEAVIATE_HOST"), + Scheme: os.Getenv("WEAVIATE_SCHEME"), + } + if cfg.Host == "" { + cfg.Host = "weaviate:8080" + } + if cfg.Scheme == "" { + cfg.Scheme = "http" } - dim, _ := strconv.Atoi(dimStr) - store, err := memory.NewStore("memory_idx", dim) + wClient, err := weaviate.NewClient(cfg) if err != nil { - log.Fatalf("failed to create memory store: %v", err) + log.Fatalf("failed to create weaviate client: %v", err) } + + + wv.EnsureSchema(wClient) + + emb := ollama.NewEmbeddingClient(os.Getenv("EMBEDDING_MODEL")) + + store := memory.NewWeaviateStore(wClient, "Memory_idx") + + m := memory.NewManager(store, emb) - log.Println("MemCortex initialized successfully!") + log.Println("MemCortex initialized with Weaviate successfully!") + chatHandler := handlers.NewChatHandler(m) mw := &middleware.MemoryMiddleware{Manager: m} mux := http.NewServeMux() + + mux.Handle("/chat", mw.Handler(chatHandler)) + + mux.HandleFunc("/api/summarize", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + userID := r.Header.Get("X-User-ID") + if userID == "" { + http.Error(w, "Missing X-User-ID header", http.StatusBadRequest) + return + } + + + if err := m.SummarizeUserMemories(r.Context(), userID); err != nil { + http.Error(w, fmt.Sprintf("Summarization failed: %v", err), http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "summarization completed"}) + }) + + mux.HandleFunc("/api/register", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req struct { + Username string `json:"username"` + Bio string `json:"bio"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Invalid body", http.StatusBadRequest) + return + } + + + newID := fmt.Sprintf("u-%d", time.Now().Unix()) + + + _, err := wClient.Data().Creator(). + WithClassName("User"). + WithProperties(map[string]interface{}{ + "username": req.Username, + "userId": newID, + "bio": req.Bio, + "createdAt": time.Now().Format(time.RFC3339), + }).Do(r.Context()) + + if err != nil { + log.Printf("Error saving user: %v", err) + http.Error(w, "Failed to register", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "user_id": newID, + "status": "Registration successful!", + }) + }) + + addr := os.Getenv("SERVER_ADDR") if addr == "" { addr = ":8080" diff --git a/docker-compose.yml b/docker-compose.yml index ad8f4d7..40dd787 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,11 +1,9 @@ services: ollama: - build: - context: . - dockerfile: Dockerfile.ollama + image: ollama/ollama:latest container_name: ollama ports: - - "${OLLAMA_ADDR}:11434" + - "11434:11434" restart: unless-stopped entrypoint: ["/bin/sh", "-c"] command: > @@ -14,12 +12,15 @@ services: ollama pull ${EMBEDDING_MODEL} && wait" volumes: - - /root/.ollama + - ./ollama_storage:/root/.ollama healthcheck: test: ["CMD", "ollama", "list"] interval: 10s timeout: 5s - retries: 5 + retries: 10 + start_period: 60s + + weaviate: image: semitechnologies/weaviate:1.25.3 @@ -34,7 +35,7 @@ services: DEFAULT_VECTORIZER_MODULE: "none" CLUSTER_HOSTNAME: "node1" volumes: - - /var/lib/weaviate + - weaviate_data:/var/lib/weaviate restart: unless-stopped go-server: build: @@ -43,13 +44,21 @@ services: container_name: go-server ports: - "${SERVER_ADDR}:8080" + env_file: + - .env environment: - OLLAMA_HOST=http://ollama:11434 - EMBEDDING_MODEL=nomic-embed-text - - WEAVIATE_HOST=http://weaviate:8080 + - SUMMARIZATION_MODEL=deepseek-r1:1.5b + - LLM_MODEL=deepseek-r1:1.5b + - WEAVIATE_HOST=weaviate:8080 + - SUMMARY_THRESHOLD=5 + - SUMMARY_BATCH_SIZE=2 depends_on: ollama: condition: service_healthy weaviate: condition: service_started restart: unless-stopped +volumes: + weaviate_data: diff --git a/internal/db/weaviate/weaviate.go b/internal/db/weaviate/weaviate.go new file mode 100644 index 0000000..178e628 --- /dev/null +++ b/internal/db/weaviate/weaviate.go @@ -0,0 +1,195 @@ +package weaviate + +import ( + "context" + "fmt" + "log" + "time" + "github.com/weaviate/weaviate-go-client/v4/weaviate" + "github.com/weaviate/weaviate-go-client/v4/weaviate/graphql" + "github.com/weaviate/weaviate-go-client/v4/weaviate/filters" + "github.com/weaviate/weaviate/entities/models" +) + +type WeaviateClient struct { + client *weaviate.Client +} + +func NewWeaviateClient(client *weaviate.Client) *WeaviateClient { + return &WeaviateClient{client: client} +} + +func (w *WeaviateClient) AddMemory(ctx context.Context, content string, userID string) error { + + if userID == "" { + userID = fmt.Sprintf("user_%d", time.Now().Unix()) + log.Printf("Warning: Empty userID provided. Falling back to generated ID: %s", userID) + } + + properties := map[string]interface{}{ + "content": content, + "userId": userID, + "timestamp": time.Now().Format(time.RFC3339), + "memoryType": "raw", + "isSummary": false, + } + + + fmt.Printf("--- [DB] SAVING MEMORY FOR USER: %s ---\n", userID) + + _, err := w.client.Data().Creator(). + WithClassName("Memory_idx"). + WithProperties(properties). + Do(ctx) + + if err != nil { + return fmt.Errorf("failed to create memory: %w", err) + } + + return nil +} + + +func (w *WeaviateClient) RegisterUser(ctx context.Context, username string, bio string) (string, error) { + userID := fmt.Sprintf("u-%d", time.Now().Unix()) + + properties := map[string]interface{}{ + "userId": userID, + "username": username, + "bio": bio, + "createdAt": time.Now().Format(time.RFC3339), + } + + _, err := w.client.Data().Creator(). + WithClassName("User"). + WithProperties(properties). + Do(ctx) + + if err != nil { + return "", fmt.Errorf("failed to register user: %w", err) + } + + return userID, nil +} + + +func (w *WeaviateClient) GetUserBio(ctx context.Context, userID string) (string, error) { + + if userID == "" { + return "", nil + } + + + where := filters.Where(). + WithPath([]string{"userId"}). + WithOperator(filters.Equal). + WithValueString(userID) + + result, err := w.client.GraphQL().Get(). + WithClassName("User"). + WithFields(graphql.Field{Name: "bio"}). + WithWhere(where). + Do(ctx) + if err != nil { + return "", fmt.Errorf("weaviate query failed: %w", err) + } + + // Safe navigation of the nested map + if result.Data["Get"] == nil { + return "", nil + } + + data, ok := result.Data["Get"].(map[string]interface{})["User"].([]interface{}) + if !ok || len(data) == 0 { + return "", nil + } + + user, ok := data[0].(map[string]interface{}) + if !ok { + return "", nil + } + + bio, _ := user["bio"].(string) + return bio, nil +} + +func EnsureSchema(client *weaviate.Client) { + ctx := context.Background() + + + ensureClass(client, ctx, &models.Class{ + Class: "Memory_idx", + Vectorizer: "none", + Properties: []*models.Property{ + {Name: "content", DataType: []string{"text"}}, + {Name: "userId", DataType: []string{"string"}}, + {Name: "timestamp", DataType: []string{"date"}}, + {Name: "memoryType", DataType: []string{"string"}}, + {Name: "isSummary", DataType: []string{"boolean"}}, + {Name: "originalIds", DataType: []string{"text[]"}}, + }, + }) + + + ensureClass(client, ctx, &models.Class{ + Class: "User", + Vectorizer: "none", + Properties: []*models.Property{ + {Name: "username", DataType: []string{"string"}}, + {Name: "userId", DataType: []string{"string"}}, + {Name: "bio", DataType: []string{"text"}}, + {Name: "createdAt", DataType: []string{"date"}}, + }, + }) +} + +func ensureClass(client *weaviate.Client, ctx context.Context, classObj *models.Class) { + exists, err := client.Schema().ClassExistenceChecker().WithClassName(classObj.Class).Do(ctx) + if err != nil { + log.Printf("Error checking schema for %s: %v", classObj.Class, err) + return + } + if !exists { + err := client.Schema().ClassCreator().WithClass(classObj).Do(ctx) + if err != nil { + log.Fatalf("Failed to create class %s: %v", classObj.Class, err) + } + } +} + +func (w *WeaviateClient) EnsureUser(ctx context.Context, userID string) error { + + where := filters.Where(). + WithPath([]string{"userId"}). + WithOperator(filters.Equal). + WithValueString(userID) + + result, err := w.client.GraphQL().Get(). + WithClassName("User"). + WithFields(graphql.Field{Name: "userId"}). + WithWhere(where). + Do(ctx) + if err != nil { + return fmt.Errorf("failed to check user existence: %w", err) + } + + data := result.Data["Get"].(map[string]interface{})["User"].([]interface{}) + + + if len(data) > 0 { + return nil + } + + + log.Printf("New user detected: %s. Performing registration...", userID) + _, err = w.client.Data().Creator(). + WithClassName("User"). + WithProperties(map[string]interface{}{ + "userId": userID, + "username": "User_" + userID, + "bio": "New MemCortex user", + "createdAt": time.Now().Format(time.RFC3339), + }).Do(ctx) + + return err +} \ No newline at end of file diff --git a/internal/handlers/chat.go b/internal/handlers/chat.go index 4707889..75f42c4 100644 --- a/internal/handlers/chat.go +++ b/internal/handlers/chat.go @@ -1,62 +1,151 @@ package handlers import ( + "bytes" + "context" "encoding/json" "fmt" "log" "net/http" + "os" + "strings" + "time" "github.com/sobowalebukola/memcortex/internal/memory" - "github.com/sobowalebukola/memcortex/internal/middleware" ) -type ChatReq struct { +type ChatRequest struct { Message string `json:"message"` } -type ChatResp struct { - Response string `json:"new_message"` - Memories []memory.Memory `json:"related_memories"` + +type ChatResponse struct { + Response string `json:"new_message"` + Memories []string `json:"related_memories"` } type ChatHandler struct { Manager *memory.Manager } -func NewChatHandler(m *memory.Manager) *ChatHandler { return &ChatHandler{Manager: m} } +func NewChatHandler(m *memory.Manager) *ChatHandler { + return &ChatHandler{Manager: m} +} func (h *ChatHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + userID := r.Header.Get("X-User-ID") - var req ChatReq +if userID == "" { + userID = fmt.Sprintf("user_%d", time.Now().Unix()) + log.Printf("No X-User-ID header found. Assigning dynamic ID: %s", userID) + } + + ctx := r.Context() + if err := h.Manager.EnsureUserExists(ctx, userID); err != nil { + log.Printf("Warning: JIT Registration failed for %s: %v", userID, err) + + } + + var req ChatRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "bad request", http.StatusBadRequest) + http.Error(w, "Bad request body", http.StatusBadRequest) return } - memIf := r.Context().Value(middleware.MemoriesCtxKey) - memories := []memory.Memory{} - if memIf != nil { - if ms, ok := memIf.([]memory.Memory); ok { - memories = ms - } + + memories, err := h.Manager.Retrieve(ctx, userID, req.Message) + if err != nil { + log.Printf("Failed to retrieve memories: %v", err) + memories = []memory.Memory{} } - response := req.Message +userBio, err := h.Manager.GetUserBio(ctx, userID) +if err != nil { + log.Printf("Could not fetch user bio: %v", err) + userBio = "a user seeking assistance" +} + + +aiResponse, err := h.callLLM(ctx, req.Message, memories, userBio) +if err != nil { + log.Printf("LLM generation failed: %v", err) + http.Error(w, "Failed to generate AI response", http.StatusInternalServerError) + return +} - if err := h.Manager.Save(r.Context(), userID, req.Message); err != nil { - log.Printf("Failed to save message for user %s: %v", userID, err) - http.Error(w, "failed to save message", http.StatusInternalServerError) - return - } - log.Printf("Saved message for user %s: %s", userID, req.Message) - out := ChatResp{ - Response: response, - Memories: memories, + go func(uID, msg, aiResp string) { + bgCtx := context.Background() + _ = h.Manager.Save(bgCtx, uID, msg) + _ = h.Manager.Save(bgCtx, uID, "AI: "+aiResp) + _ = h.Manager.CheckAndSummarize(bgCtx, uID) + }(userID, req.Message, aiResponse) + + + cleanMemories := make([]string, 0, len(memories)) + for _, m := range memories { + cleanMemories = append(cleanMemories, m.Content) } + w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(out); err != nil { - http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) - return - } + json.NewEncoder(w).Encode(ChatResponse{ + Response: aiResponse, + Memories: cleanMemories, + }) } + + +type ollamaRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + Stream bool `json:"stream"` +} + +type ollamaResponse struct { + Response string `json:"response"` +} + +func (h *ChatHandler) callLLM(ctx context.Context, userMessage string, memories []memory.Memory, userBio string) (string, error) { + var contextBuilder strings.Builder + for _, m := range memories { + contextBuilder.WriteString(fmt.Sprintf("- %s\n", m.Content)) + } + var bioSection string + if userBio != "" { + bioSection = fmt.Sprintf("User Info: %s. ", userBio) + } + + systemPrompt := fmt.Sprintf("You are the MemCortex Assistant. "+ + "Context: %s. "+ + "Rules: "+ + "1. Use the Context above to answer. "+ + "2. Be concise (under 3 sentences). ", bioSection) + + fullPrompt := fmt.Sprintf("%s\n\nContext:\n%s\n\nUser: %s", systemPrompt, contextBuilder.String(), userMessage) + + model := os.Getenv("LLM_MODEL") + if model == "" { + model = "deepseek-r1:1.5b" + } + + reqBody := ollamaRequest{ + Model: model, + Prompt: fullPrompt, + Stream: false, + } + + jsonData, _ := json.Marshal(reqBody) + resp, err := http.Post("http://ollama:11434/api/generate", "application/json", bytes.NewBuffer(jsonData)) + if err != nil { + return "", err + } + defer resp.Body.Close() + + var result ollamaResponse + json.NewDecoder(resp.Body).Decode(&result) + return result.Response, nil +} \ No newline at end of file diff --git a/internal/memory/manager.go b/internal/memory/manager.go index 064ba70..174c101 100644 --- a/internal/memory/manager.go +++ b/internal/memory/manager.go @@ -4,16 +4,14 @@ import ( "context" "fmt" "log" - "strings" - "time" - - weaviate "github.com/weaviate/weaviate-go-client/v4/weaviate" - "os" "strconv" + "strings" "github.com/joho/godotenv" ollama "github.com/sobowalebukola/memcortex/internal/embedder" + "github.com/sobowalebukola/memcortex/internal/summarizer" + "github.com/weaviate/weaviate-go-client/v4/weaviate" ) type Manager struct { @@ -21,6 +19,7 @@ type Manager struct { Embedder *ollama.EmbeddingClient TopK int WeaviateClient *weaviate.Client + summarizer *summarizer.Summarizer } type MemoryPrompt struct { @@ -29,7 +28,6 @@ type MemoryPrompt struct { } func NewManager(store *Store, emb *ollama.EmbeddingClient) *Manager { - err := godotenv.Load() if err != nil { log.Println("No .env file found, using system environment") @@ -43,7 +41,13 @@ func NewManager(store *Store, emb *ollama.EmbeddingClient) *Manager { if err != nil { topK = 10 } - return &Manager{Store: store, Embedder: emb, TopK: topK} + + return &Manager{ + Store: store, + Embedder: emb, + TopK: topK, + summarizer: summarizer.NewSummarizer(), + } } func (m *Manager) Retrieve(ctx context.Context, userID, query string) ([]Memory, error) { @@ -52,7 +56,13 @@ func (m *Manager) Retrieve(ctx context.Context, userID, query string) ([]Memory, if err != nil { return nil, err } - return m.Store.Search(ctx, emb, userID, m.TopK) + + emb32 := make([]float32, len(emb)) + for i, v := range emb { + emb32[i] = float32(v) + } + + return m.Store.Search(ctx, emb32, userID, m.TopK) } func (m *Manager) SaveAsync(ctx context.Context, userID, text string) { @@ -63,14 +73,15 @@ func (m *Manager) SaveAsync(ctx context.Context, userID, text string) { func (m *Manager) Save(ctx context.Context, userID, text string) error { emb, err := m.Embedder.Embed(ctx, text) - if err != nil { return err } + emb32 := make([]float32, len(emb)) for i, v := range emb { emb32[i] = float32(v) } + _, err = m.Store.Save(ctx, userID, text, emb32) return err } @@ -83,11 +94,9 @@ func FormatMemoryPrompt(memories []Memory) []MemoryPrompt { result := make([]MemoryPrompt, 0, len(memories)) for i, mem := range memories { - ts := mem.Timestamp.Format(time.RFC3339) - result = append(result, MemoryPrompt{ - Text: mem.Text, - Added: ts, + Text: mem.Content, + Added: mem.Timestamp, }) if i >= 20 { @@ -98,38 +107,94 @@ func FormatMemoryPrompt(memories []Memory) []MemoryPrompt { return result } -type MemoryManager struct { - queue *EmbeddingQueue - store *Store -} -func NewMemoryManager(queue *EmbeddingQueue, store *Store) *MemoryManager { - return &MemoryManager{ - queue: queue, - store: store, + +func (m *Manager) CheckAndSummarize(ctx context.Context, userID string) error { + if !m.isAutoSummaryEnabled() { + return nil + } + + count, err := m.Store.GetMemoryCount(ctx, userID) + if err != nil { + return fmt.Errorf("failed to get memory count: %w", err) + } + + threshold := summarizer.GetSummaryThreshold() + if count < threshold { + return nil } + + log.Printf("Memory count (%d) exceeded threshold (%d) for user %s, triggering summarization", + count, threshold, userID) + + return m.SummarizeUserMemories(ctx, userID) } -func (m *MemoryManager) SaveMemory(ctx context.Context, userID, text string) error { - go func() { - embedding, err := m.queue.Enqueue(userID, text) - if err != nil { - fmt.Println("Error generating embedding:", err) - return - } +func (m *Manager) SummarizeUserMemories(ctx context.Context, userID string) error { + batchSize := summarizer.GetSummaryBatchSize() + maxAge := summarizer.GetSummaryMaxAgeDays() - embedding32 := make([]float32, len(embedding)) + memories, err := m.Store.GetOldMemories(ctx, userID, maxAge, batchSize) + if err != nil { + return fmt.Errorf("failed to get old memories: %w", err) + } - for i, v := range embedding { - embedding32[i] = float32(v) - } - id, err := m.store.Save(ctx, userID, text, embedding32) - if err != nil { - fmt.Println("Error saving memory:", err) - return - } + if len(memories) == 0 { + log.Printf("No memories to summarize for user %s", userID) + return nil + } - log.Printf("Memory saved with ID: %s", id) - }() + contents := make([]string, len(memories)) + ids := make([]string, len(memories)) + for i, mem := range memories { + contents[i] = mem.Content + ids[i] = mem.ID + } + + summary, err := m.summarizer.SummarizeMemories(ctx, contents, userID) + if err != nil { + return fmt.Errorf("failed to generate summary: %w", err) + } + + summaryEmb, err := m.Embedder.Embed(ctx, summary) + if err != nil { + return fmt.Errorf("failed to embed summary: %w", err) + } + + emb32 := make([]float32, len(summaryEmb)) + for i, v := range summaryEmb { + emb32[i] = float32(v) + } + + log.Printf("Generated summary for %d memories (user %s)", len(memories), userID) + + if err := m.Store.SaveSummary(ctx, summary, userID, ids, emb32); err != nil { + return fmt.Errorf("failed to save summary: %w", err) + } + + if err := m.Store.DeleteMemories(ctx, ids); err != nil { + return fmt.Errorf("failed to delete original memories: %w", err) + } + + log.Printf("Successfully summarized %d memories for user %s", len(memories), userID) return nil } + +func (m *Manager) isAutoSummaryEnabled() bool { + enabled := os.Getenv("ENABLE_AUTO_SUMMARY") + if enabled == "" { + return true + } + val, _ := strconv.ParseBool(enabled) + return val +} +func (m *Manager) GetUserBio(ctx context.Context, userID string) (string, error) { + + return m.Store.GetUserBio(ctx, userID) +} + + +func (m *Manager) EnsureUserExists(ctx context.Context, userID string) error { + + return m.Store.EnsureUser(ctx, userID) +} \ No newline at end of file diff --git a/internal/memory/store.go b/internal/memory/store.go index 75fa590..61c8a90 100644 --- a/internal/memory/store.go +++ b/internal/memory/store.go @@ -4,70 +4,48 @@ import ( "context" "fmt" "time" - "log" - "os" - "strconv" "github.com/google/uuid" - "github.com/joho/godotenv" "github.com/weaviate/weaviate-go-client/v4/weaviate" "github.com/weaviate/weaviate-go-client/v4/weaviate/filters" "github.com/weaviate/weaviate-go-client/v4/weaviate/graphql" + "github.com/weaviate/weaviate/entities/models" ) type Memory struct { - ID string `json:"id"` - Text string `json:"text"` - Timestamp time.Time `json:"timestamp"` + ID string `json:"id"` + Content string `json:"content"` + Timestamp string `json:"timestamp"` + UserID string `json:"userId"` + MemoryType string `json:"memoryType"` } type Store struct { Client *weaviate.Client Class string - Dim int -} - -type SearchResult struct { - ID string - Properties map[string]interface{} } -// --------------------------- -// Initialize Weaviate store -// --------------------------- -func NewStore(class string, dim int) (*Store, error) { - client, err := weaviate.NewClient(weaviate.Config{ - Host: "weaviate:8080", - Scheme: "http", - }) - - if err != nil { - return nil, err - } - +func NewWeaviateStore(client *weaviate.Client, class string) *Store { return &Store{ Client: client, Class: class, - Dim: dim, - }, nil + } } -// --------------------------- -// Save Memory -// --------------------------- func (s *Store) Save(ctx context.Context, userID, text string, embedding []float32) (string, error) { - if len(embedding) != s.Dim { - return "", fmt.Errorf("embedding dimension mismatch") - } - id := uuid.New().String() + if userID == "" { + userID = fmt.Sprintf("user_%d", time.Now().Unix()) + log.Printf("Warning: Store received empty userID. Generated dynamic ID: %s", userID) + } - data := map[string]any{ - "user_id": userID, - "text": text, - "timestamp": time.Now().Unix(), - "embedding": embedding, + data := map[string]interface{}{ + "userId": userID, + "content": text, + "timestamp": time.Now().Format(time.RFC3339), + "memoryType": "raw", + "isSummary": false, } _, err := s.Client.Data(). @@ -78,110 +56,256 @@ func (s *Store) Save(ctx context.Context, userID, text string, embedding []float WithVector(embedding). Do(ctx) - if err != nil { - return "", err + return id, err +} + +func (s *Store) GetMemoryCount(ctx context.Context, userID string) (int, error) { + fmt.Printf(">>> [DB] Counting memories for user: %s\n", userID) + + where := filters.Where(). + WithPath([]string{"userId"}). + WithOperator(filters.Equal). + WithValueString(userID) + + resp, err := s.Client.GraphQL().Aggregate(). + WithClassName(s.Class). + WithWhere(where). + WithFields(graphql.Field{ + Name: "meta", + Fields: []graphql.Field{{Name: "count"}}, + }). + Do(ctx) + + if err != nil { return 0, err } + + data, ok := resp.Data["Aggregate"].(map[string]interface{}) + if !ok { return 0, nil } + + classData, ok := data[s.Class].([]interface{}) + if !ok || len(classData) == 0 { return 0, nil } + + fields, ok := classData[0].(map[string]interface{}) + if !ok { return 0, nil } + + meta, ok := fields["meta"].(map[string]interface{}) + if !ok { return 0, nil } + + count, ok := meta["count"].(float64) + if !ok { return 0, nil } + + fmt.Printf(">>> [DB] Database count for %s is: %d\n", userID, int(count)) + return int(count), nil +} + +func (s *Store) GetOldMemories(ctx context.Context, userID string, olderThanDays int, limit int) ([]Memory, error) { + where := filters.Where(). + WithOperator(filters.And). + WithOperands([]*filters.WhereBuilder{ + filters.Where(). + WithPath([]string{"userId"}). + WithOperator(filters.Equal). + WithValueString(userID), + filters.Where(). + WithPath([]string{"isSummary"}). + WithOperator(filters.Equal). + WithValueBoolean(false), + }) + + fields := []graphql.Field{ + {Name: "content"}, + {Name: "timestamp"}, + {Name: "memoryType"}, + {Name: "_additional", Fields: []graphql.Field{{Name: "id"}}}, } - return id, nil + resp, err := s.Client.GraphQL().Get(). + WithClassName(s.Class). + WithWhere(where). + WithLimit(limit). + WithFields(fields...). + Do(ctx) + + if err != nil { return nil, err } + return s.parseGraphQLResponse(resp) } -// --------------------------- -// Vector Search in Weaviate -// --------------------------- -func (s *Store) Search(ctx context.Context, queryEmbedding []float64, userID string, k int) ([]Memory, error) { - if len(queryEmbedding) != s.Dim { - return nil, fmt.Errorf("embedding dim mismatch: expected %d, got %d", s.Dim, len(queryEmbedding)) +func (s *Store) SaveSummary(ctx context.Context, summary string, userID string, originalIDs []string, embedding []float32) error { + data := map[string]interface{}{ + "content": summary, + "userId": userID, + "timestamp": time.Now().Format(time.RFC3339), + "memoryType": "summary", + "isSummary": true, + "originalIds": originalIDs, } - vec := float64ToFloat32Slice(queryEmbedding) + _, err := s.Client.Data(). + Creator(). + WithClassName(s.Class). + WithProperties(data). + WithVector(embedding). + Do(ctx) + + return err +} - err := godotenv.Load() - if err != nil { - log.Println("No .env file found, using system environment") +func (s *Store) DeleteMemories(ctx context.Context, ids []string) error { + for _, id := range ids { + _ = s.Client.Data().Deleter().WithClassName(s.Class).WithID(id).Do(ctx) } + return nil +} + +func (s *Store) parseGraphQLResponse(resp *models.GraphQLResponse) ([]Memory, error) { + var memories []Memory + + data, ok := resp.Data["Get"].(map[string]interface{}) + if !ok { return nil, nil } + + objects, ok := data[s.Class].([]interface{}) + if !ok { return nil, nil } + + for _, obj := range objects { + item, ok := obj.(map[string]interface{}) + if !ok { continue } + + mem := Memory{} + if content, ok := item["content"].(string); ok { + mem.Content = content + } + if ts, ok := item["timestamp"].(string); ok { + mem.Timestamp = ts + } + if mt, ok := item["memoryType"].(string); ok { + mem.MemoryType = mt + } - maxMemoryDistance, err := strconv.ParseFloat(os.Getenv("MAX_MEMORY_DISTANCE"), 64) - if err != nil || maxMemoryDistance == 0 { - maxMemoryDistance = 0.5 + if additional, ok := item["_additional"].(map[string]interface{}); ok { + if id, ok := additional["id"].(string); ok { + mem.ID = id + } + } + memories = append(memories, mem) } + return memories, nil +} - nearVector := s.Client.GraphQL().NearVectorArgBuilder(). - WithVector(vec).WithDistance(float32(maxMemoryDistance)) +func (s *Store) Search(ctx context.Context, queryEmbedding []float32, userID string, k int) ([]Memory, error) { + + fields := []graphql.Field{ + {Name: "content"}, + {Name: "timestamp"}, + {Name: "memoryType"}, + {Name: "_additional", Fields: []graphql.Field{ + {Name: "id"}, + {Name: "distance"}, + }}, + } where := filters.Where(). - WithPath([]string{"user_id"}). + WithPath([]string{"userId"}). WithOperator(filters.Equal). WithValueString(userID) - query := s.Client.GraphQL().Get(). - WithClassName("Memory_idx"). + nearVector := s.Client.GraphQL().NearVectorArgBuilder(). + WithVector(queryEmbedding) + + resp, err := s.Client.GraphQL().Get(). + WithClassName(s.Class). WithWhere(where). WithNearVector(nearVector). WithLimit(k). - WithFields( - graphql.Field{Name: "text"}, - graphql.Field{Name: "timestamp"}, - graphql.Field{ - Name: "_additional", - Fields: []graphql.Field{ - {Name: "id"}, - {Name: "distance"}, - }, - }, - ) - resp, err := query.Do(ctx) + WithFields(fields...). + Do(ctx) + + if err != nil { return nil, err } + return s.parseGraphQLResponse(resp) +} + + +func (s *Store) GetUserBio(ctx context.Context, userID string) (string, error) { + + where := filters.Where(). + WithPath([]string{"userId"}). + WithOperator(filters.Equal). + WithValueString(userID) + + result, err := s.Client.GraphQL().Get(). + WithClassName("User") . + WithFields(graphql.Field{Name: "bio"}). + WithWhere(where). + Do(ctx) if err != nil { - return nil, fmt.Errorf("graphql error: %w", err) + return "", fmt.Errorf("failed to fetch user bio: %w", err) } - if resp.Errors != nil { - return nil, fmt.Errorf("weaviate error: %v", resp.Errors) + + if result.Data == nil || result.Data["Get"] == nil { + return "", fmt.Errorf("no data found in Weaviate") } - getNode, ok := resp.Data["Get"].(map[string]interface{}) - if !ok { - return nil, nil + getMap := result.Data["Get"].(map[string]interface{}) + users, ok := getMap["User"].([]interface{}) + + + if !ok || len(users) == 0 { + return "A software project called MemCortex focusing on long-term AI memory.", nil } - raw, ok := getNode["Memory_idx"].([]interface{}) - if !ok { - return nil, nil - } + userFields := users[0].(map[string]interface{}) + bio, _ := userFields["bio"].(string) - results := make([]Memory, 0, len(raw)) + return bio, nil +} - for _, item := range raw { - obj, ok := item.(map[string]interface{}) - if !ok { - continue - } - mem := Memory{} - if v, ok := obj["text"].(string); ok { - mem.Text = v - } +func (s *Store) EnsureUser(ctx context.Context, userID string) error { + + where := filters.Where(). + WithPath([]string{"userId"}). + WithOperator(filters.Equal). + WithValueString(userID) - if ts, ok := obj["timestamp"].(float64); ok { - mem.Timestamp = time.Unix(int64(ts), 0) - } + + result, err := s.Client.GraphQL().Get(). + WithClassName("User"). + WithFields(graphql.Field{Name: "userId"}). + WithWhere(where). + Do(ctx) - if add, ok := obj["_additional"].(map[string]interface{}); ok { - if id, ok := add["id"].(string); ok { - mem.ID = id - } + if err != nil { + return fmt.Errorf("failed to check user existence: %w", err) + } + + + if result.Data != nil && result.Data["Get"] != nil { + getMap := result.Data["Get"].(map[string]interface{}) + users, ok := getMap["User"].([]interface{}) + if ok && len(users) > 0 { + + return nil } + } - results = append(results, mem) + + log.Printf(">>> [DB] New user detected: %s. Performing JIT registration...", userID) + + properties := map[string]interface{}{ + "userId": userID, + "username": "User_" + userID, + "bio": "A new user of the MemCortex system.", } - return results, nil -} -func float64ToFloat32Slice(f []float64) []float32 { - out := make([]float32, len(f)) - for i, v := range f { - out[i] = float32(v) + _, err = s.Client.Data().Creator(). + WithClassName("User"). + WithProperties(properties). + Do(ctx) + + if err != nil { + return fmt.Errorf("failed to create new user: %w", err) } - return out -} + + return nil +} \ No newline at end of file diff --git a/internal/summarizer/summarizer.go b/internal/summarizer/summarizer.go new file mode 100644 index 0000000..0e905b4 --- /dev/null +++ b/internal/summarizer/summarizer.go @@ -0,0 +1,124 @@ +package summarizer + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strconv" + "time" +) + +type Summarizer struct { + ollamaHost string + model string + client *http.Client +} + +type OllamaRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + Stream bool `json:"stream"` +} + +type OllamaResponse struct { + Response string `json:"response"` + Done bool `json:"done"` +} + +func NewSummarizer() *Summarizer { + return &Summarizer{ + ollamaHost: getEnv("OLLAMA_HOST", "http://ollama:11434"), + model: getEnv("SUMMARIZATION_MODEL", "deepseek-r1:1.5b"), + client: &http.Client{ + Timeout: 5 * time.Minute, + }, + } +} + +func getEnv(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue +} + +func (s *Summarizer) SummarizeMemories(ctx context.Context, memories []string, userID string) (string, error) { + if len(memories) == 0 { + return "", fmt.Errorf("no memories to summarize") + } + + // This is the line that will finally show up in your terminal + fmt.Printf("\n[SUMMARIZER] Triggered for user: %s | Batch size: %d\n", userID, len(memories)) + + prompt := s.buildSummarizationPrompt(memories, userID) + + reqBody := OllamaRequest{ + Model: s.model, + Prompt: prompt, + Stream: false, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return "", fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", + fmt.Sprintf("%s/api/generate", s.ollamaHost), + bytes.NewBuffer(jsonData)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + fmt.Println("[SUMMARIZER] Calling Ollama (DeepSeek)...") + resp, err := s.client.Do(req) + if err != nil { + return "", fmt.Errorf("failed to call ollama: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("ollama returned status %d: %s", resp.StatusCode, string(body)) + } + + var ollamaResp OllamaResponse + if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil { + return "", fmt.Errorf("failed to decode response: %w", err) + } + + fmt.Println("[SUMMARIZER] Successfully generated summary.") + return ollamaResp.Response, nil +} + +func (s *Summarizer) buildSummarizationPrompt(memories []string, userID string) string { + memoriesText := "" + for _, mem := range memories { + memoriesText += fmt.Sprintf("- %s\n", mem) + } + + return fmt.Sprintf(`Summarize these memories for user "%s" into a single concise paragraph. +Do not include tags. Just provide the raw summary. +Memories: +%s`, userID, memoriesText) +} + +func GetSummaryThreshold() int { + val, _ := strconv.Atoi(getEnv("SUMMARY_THRESHOLD", "2")) + return val +} + +func GetSummaryBatchSize() int { + val, _ := strconv.Atoi(getEnv("SUMMARY_BATCH_SIZE", "2")) + return val +} + +func GetSummaryMaxAgeDays() int { + val, _ := strconv.Atoi(getEnv("SUMMARY_MAX_AGE_DAYS", "0")) + return val +} \ No newline at end of file