Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ class AiResourceRestTests {
RestAssuredMockMvc.mockMvc(mockMvc)
}

private fun getTestAdminToken(): String {
return tokenManager.generateAccessToken(
UserId(1L),
setOf(Role.ADMIN)
).value
}

private fun getTestUserToken(): String {
return tokenManager.generateAccessToken(
UserId(1L),
Expand All @@ -96,6 +103,20 @@ class AiResourceRestTests {
fun contextLoads() {
}

@Nested
inner class `권한 테스트` {
@Test
fun `일반 USER 권한으로 AI API 접근 시 403을 반환한다`() {
RestAssuredMockMvc.given()
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestUserToken())
.contentType(MediaType.APPLICATION_JSON_VALUE)
.body(ChatRequest(sessionId = 1L, message = "안녕"))
.post("/api/ai/chat")
.then()
.statusCode(403)
}
}

@Nested
inner class `채팅 API 테스트` {
@Test
Expand All @@ -119,7 +140,7 @@ class AiResourceRestTests {

// when & then
RestAssuredMockMvc.given()
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestUserToken())
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestAdminToken())
.contentType(MediaType.APPLICATION_JSON_VALUE)
.body(ChatRequest(sessionId = 1L, message = "안녕"))
.post("/api/ai/chat")
Expand Down Expand Up @@ -150,7 +171,7 @@ class AiResourceRestTests {

// when & then
RestAssuredMockMvc.given()
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestUserToken())
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestAdminToken())
.contentType(MediaType.APPLICATION_JSON_VALUE)
.body(CreateSessionRequest(title = "새 세션"))
.post("/api/ai/sessions")
Expand Down Expand Up @@ -187,7 +208,7 @@ class AiResourceRestTests {

// when & then
RestAssuredMockMvc.given()
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestUserToken())
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestAdminToken())
.get("/api/ai/sessions")
.then()
.statusCode(200)
Expand Down Expand Up @@ -222,7 +243,7 @@ class AiResourceRestTests {

// when & then
RestAssuredMockMvc.given()
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestUserToken())
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestAdminToken())
.get("/api/ai/sessions/1/messages")
.then()
.statusCode(200)
Expand Down Expand Up @@ -254,7 +275,7 @@ class AiResourceRestTests {

// when & then
RestAssuredMockMvc.given()
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestUserToken())
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestAdminToken())
.queryParam("query", "테스트")
.get("/api/ai/recommend")
.then()
Expand All @@ -279,7 +300,7 @@ class AiResourceRestTests {

// when & then
RestAssuredMockMvc.given()
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestUserToken())
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestAdminToken())
.post("/api/ai/sync")
.then()
.statusCode(200)
Expand All @@ -303,7 +324,7 @@ class AiResourceRestTests {

// when & then
RestAssuredMockMvc.given()
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestUserToken())
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestAdminToken())
.delete("/api/ai/sessions/1")
.then()
.statusCode(204)
Expand All @@ -323,7 +344,7 @@ class AiResourceRestTests {

// when & then
RestAssuredMockMvc.given()
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestUserToken())
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestAdminToken())
.queryParam("cursor", 10)
.queryParam("size", 20)
.get("/api/ai/sessions/1/messages")
Expand All @@ -338,7 +359,7 @@ class AiResourceRestTests {

// when & then
RestAssuredMockMvc.given()
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestUserToken())
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestAdminToken())
.queryParam("memoId", 5)
.queryParam("topK", 3)
.get("/api/ai/recommend")
Expand All @@ -357,7 +378,7 @@ class AiResourceRestTests {

// when & then
RestAssuredMockMvc.given()
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestUserToken())
.cookies(PreAuthFilter.ACCESS_TOKEN_HEADER, getTestAdminToken())
.contentType(MediaType.APPLICATION_JSON_VALUE)
.body(ChatRequest(sessionId = 1L, message = "메모 생성해줘"))
.post("/api/ai/chat")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,16 @@
package kr.co.jiniaslog.ai.adapter.inbound.http

import io.swagger.v3.oas.models.OpenAPI
import io.swagger.v3.oas.models.info.Info
import io.swagger.v3.oas.models.tags.Tag
import org.springdoc.core.models.GroupedOpenApi
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration

@Configuration
class AiSwaggerConfig {

@Bean
fun aiOpenApi(): OpenAPI {
return OpenAPI()
.info(
Info()
.title("AI Second Brain API")
.description("AI 기반 챗봇 및 메모 추천 API")
.version("1.0.0")
)
.addTagsItem(Tag().name("Chat").description("AI 챗봇 관련 API"))
.addTagsItem(Tag().name("Recommend").description("메모 추천 API"))
.addTagsItem(Tag().name("Sync").description("임베딩 동기화 API"))
fun aiApi(): GroupedOpenApi {
return GroupedOpenApi.builder()
.group("AI Second Brain")
.packagesToScan("kr.co.jiniaslog.ai.adapter.inbound.http")
.build()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class AsyncConfig {
private val counter = AtomicInteger(0)
override fun newThread(r: Runnable): Thread {
return Thread(r, "embedding-scheduler-${counter.incrementAndGet()}").apply {
isDaemon = false
isDaemon = true
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package kr.co.jiniaslog.ai.adapter.inbound.message

import jakarta.annotation.PreDestroy
import kr.co.jiniaslog.ai.usecase.IDeleteMemoEmbedding
import kr.co.jiniaslog.ai.usecase.ISyncMemoToEmbedding
import kr.co.jiniaslog.memo.domain.memo.MemoCreatedEvent
Expand All @@ -26,13 +27,21 @@ class MemoEventListener(
private val scheduler: ScheduledExecutorService,
) {
companion object {
private const val DEBOUNCE_DELAY_SECONDS = 5L
private const val DEBOUNCE_DELAY_MINUTES = 10L
}

// 메모별 디바운스 타이머
private val pendingUpdates = ConcurrentHashMap<Long, ScheduledFuture<*>>()
private val pendingEvents = ConcurrentHashMap<Long, MemoUpdatedEvent>()

@PreDestroy
fun shutdown() {
pendingUpdates.values.forEach { it.cancel(false) }
pendingUpdates.clear()
pendingEvents.clear()
scheduler.shutdownNow()
}

@Async("aiEmbeddingExecutor")
@TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
fun handleMemoCreated(event: MemoCreatedEvent) {
Expand Down Expand Up @@ -80,7 +89,7 @@ class MemoEventListener(
logger.error(e) { "Failed to update memo ${evt.memoId} in embedding store" }
}
}
}, DEBOUNCE_DELAY_SECONDS, TimeUnit.SECONDS)
}, DEBOUNCE_DELAY_MINUTES, TimeUnit.MINUTES)

pendingUpdates[event.memoId] = future
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,21 @@ class ChromaEmbeddingStoreAdapter(
private const val MEMO_ID_KEY = "memoId"
private const val AUTHOR_ID_KEY = "authorId"
private const val TITLE_KEY = "title"
private const val CHUNK_SIZE = 1500
private const val CHUNK_OVERLAP = 200
private const val MAX_CHUNK_COUNT = 50
}

override fun store(document: MemoEmbeddingDocument) {
val doc = Document.builder()
.id("${document.memoId}")
.text("${document.title}\n\n${document.content}")
.metadata(
mapOf(
MEMO_ID_KEY to document.memoId.toString(),
AUTHOR_ID_KEY to document.authorId.toString(),
TITLE_KEY to document.title,
)
)
.build()
vectorStore.add(listOf(doc))
}
delete(document.memoId)

override fun storeAll(documents: List<MemoEmbeddingDocument>) {
val docs = documents.map { document ->
val fullText = "${document.title}\n\n${document.content}"
val chunks = chunkText(fullText)

val docs = chunks.mapIndexed { index, chunk ->
Document.builder()
.id("${document.memoId}")
.text("${document.title}\n\n${document.content}")
.id("${document.memoId}_$index")
.text(chunk)
.metadata(
mapOf(
MEMO_ID_KEY to document.memoId.toString(),
Expand All @@ -51,26 +44,76 @@ class ChromaEmbeddingStoreAdapter(
vectorStore.add(docs)
}

override fun storeAll(documents: List<MemoEmbeddingDocument>) {
val docs = documents.flatMap { document ->
val fullText = "${document.title}\n\n${document.content}"
val chunks = chunkText(fullText)

chunks.mapIndexed { index, chunk ->
Document.builder()
.id("${document.memoId}_$index")
.text(chunk)
.metadata(
mapOf(
MEMO_ID_KEY to document.memoId.toString(),
AUTHOR_ID_KEY to document.authorId.toString(),
TITLE_KEY to document.title,
)
)
.build()
}
}
vectorStore.add(docs)
}

override fun delete(memoId: Long) {
vectorStore.delete(listOf("$memoId"))
val idsToDelete = mutableListOf("$memoId")
for (index in 0 until MAX_CHUNK_COUNT) {
idsToDelete.add("${memoId}_$index")
}
vectorStore.delete(idsToDelete)
}

override fun searchSimilar(query: String, authorId: Long, topK: Int): List<SimilarMemo> {
val request = SearchRequest.builder()
.query(query)
.topK(topK * 2)
.topK(topK * 3)
.filterExpression("$AUTHOR_ID_KEY == '$authorId'")
.build()

return vectorStore.similaritySearch(request)
.take(topK)
.map { doc: Document ->
.groupBy { doc -> (doc.metadata[MEMO_ID_KEY] as String).toLong() }
.mapNotNull { (memoId, chunks) ->
// Pick the chunk with highest similarity
val bestChunk = chunks.maxByOrNull { it.score ?: 0.0 } ?: return@mapNotNull null
SimilarMemo(
memoId = (doc.metadata[MEMO_ID_KEY] as String).toLong(),
title = doc.metadata[TITLE_KEY] as String,
content = doc.text ?: "",
similarity = doc.score ?: 0.0,
memoId = memoId,
title = bestChunk.metadata[TITLE_KEY] as String,
content = bestChunk.text ?: "",
similarity = bestChunk.score ?: 0.0,
)
}
.sortedByDescending { it.similarity }
.take(topK)
}

private fun chunkText(text: String): List<String> {
if (text.length <= CHUNK_SIZE) {
return listOf(text)
}

val chunks = mutableListOf<String>()
var startIndex = 0

while (startIndex < text.length) {
val endIndex = minOf(startIndex + CHUNK_SIZE, text.length)
chunks.add(text.substring(startIndex, endIndex))

if (endIndex >= text.length) break

startIndex = endIndex - CHUNK_OVERLAP
}

return chunks
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class SecurityConfig(
it.requestMatchers("/api/v1/media").authenticated()
it.requestMatchers("/api/v1/memos/**", "/api/v1/memos").authenticated()
it.requestMatchers("/api/v1/folders/**", "/api/v1/folders").authenticated()
it.requestMatchers("/api/ai/**").authenticated()
it.requestMatchers("/api/ai/**").hasRole("ADMIN")
it.anyRequest().permitAll()
}
.headers { it.frameOptions(Customizer { it.disable() }) }
Expand Down