From 9dfe62a2bbd0b61bb8d25ca58b9730d5bbf392cd Mon Sep 17 00:00:00 2001 From: Thiraput01 Date: Mon, 28 Apr 2025 20:47:23 +0700 Subject: [PATCH 1/2] feat/agentic --- src/cp_genie/api/v1/router.py | 3 +- src/cp_genie/domain/rag/agentic.py | 122 ++++++++++++++++++----------- src/cp_genie/domain/rag/normal.py | 29 +++---- 3 files changed, 95 insertions(+), 59 deletions(-) diff --git a/src/cp_genie/api/v1/router.py b/src/cp_genie/api/v1/router.py index 8d6a38e..475d38c 100644 --- a/src/cp_genie/api/v1/router.py +++ b/src/cp_genie/api/v1/router.py @@ -2,6 +2,7 @@ from fastapi import Request from cp_genie.domain.rag.normal import NormalRAG +from cp_genie.domain.rag.agentic import AgenticRAG from cp_genie.api.v1.schema import ChatRequest, ChatResponse from cp_genie.infrastructure.chat_memory import get_by_session_id @@ -9,7 +10,7 @@ RAG_CLASSES = { "normal": NormalRAG, - # "agentic": AgenticRAG, + "agentic": AgenticRAG, } diff --git a/src/cp_genie/domain/rag/agentic.py b/src/cp_genie/domain/rag/agentic.py index ced75c1..f0fde77 100644 --- a/src/cp_genie/domain/rag/agentic.py +++ b/src/cp_genie/domain/rag/agentic.py @@ -1,75 +1,109 @@ from langgraph.graph import StateGraph, END -from langchain_core.runnables import RunnableLambda from langchain.prompts import ChatPromptTemplate from langchain.chains.combine_documents import create_stuff_documents_chain from cp_genie.domain.rag.base import State -from langchain.agents import Tool, AgentExecutor, create_tool_calling_agent +from langchain_core.messages import HumanMessage, ToolMessage +from langchain_core.tools import tool +from langgraph.prebuilt import ToolNode, tools_condition +from typing import List +from langchain_core.documents import Document class AgenticRAG: - def __init__(self, llm, tools, memory): + def __init__(self, llm, retriever, memory): self.llm = llm - self.tools = tools # list of Tool objects + self.retriever = retriever self.memory = memory self.chain = self._build_graph() def _build_graph(self) -> StateGraph: - def query(state) -> State: - self.memory.add_user_message(state["question"]) - llm_with_tools = self.llm.bind_tools(self.tools) - response = llm_with_tools.invoke(state["messages"]) - return {"messages": state["messages"] + [response]} - - def retrieve(state) -> State: - tool_inputs = {"input": state["question"]} - tool_result = agent_executor.invoke(tool_inputs) - return {**state, "context": tool_result.get("output", [])} - - def generate(state) -> State: - result = combine_chain.invoke( - { - "messages": state["messages"], - "context": state["context"], - "question": state["question"], - } - ) - self.memory.add_user_message(state["question"]) - self.memory.add_ai_message(result) - return {**state, "output": result} + @tool + def retrieve(query: str) -> List[Document]: + """ + Retrieves information related to the input query + from a vector database containing information + on Computer Engineering at Chulalongkorn University. + Use this tool ONLY when the user asks a question that requires specific + knowledge about the university or department. Do not use for general conversation. + """ + docs = self.retriever.invoke(query) + return docs + + tools = [retrieve] + tool_node = ToolNode(tools) + llm_with_tools = self.llm.bind_tools(tools) prompt = ChatPromptTemplate.from_messages( [ - ( - "system", - "Use the available tools to retrieve useful context and answer as concisely as possible.", - ), + ("system", "Answer as concisely as possible."), ( "human", - "chat history: {messages}\nretrieved context: {context}\nquestion: {question}", + "chat history: {messages}\nretrieved context: {context}\n", ), ] ) combine_chain = create_stuff_documents_chain(self.llm, prompt) - # Set up an agent to choose tools - agent = create_tool_calling_agent(self.llm, self.tools) - agent_executor = AgentExecutor(agent=agent, tools=self.tools, verbose=True) + def query_or_respond(state: State) -> dict: + print("--- Agent Node: Query or Respond ---") + messages = state["messages"] + response = llm_with_tools.invoke(messages) + return {"messages": [response]} + + def generate(state: State) -> dict: + print("--- Generate Node ---") + messages = state["messages"] + last_message = messages[-1] + + if not isinstance(last_message, ToolMessage): + raise ValueError( + "Last message is not a ToolMessage. Generation node expects tool output." + ) + + retrieved_docs = last_message.content + if not isinstance(retrieved_docs, list) or not all( + isinstance(doc, Document) for doc in retrieved_docs + ): + if isinstance(retrieved_docs, str): + retrieved_docs = [Document(page_content=retrieved_docs)] + else: + retrieved_docs = [Document(page_content=str(retrieved_docs))] + + generation = combine_chain.invoke( + { + "messages": messages, + "context": retrieved_docs, + } + ) + + self.memory.add_ai_message(generation) + return {"messages": [generation]} graph = StateGraph(State) - graph.add_node("retrieve", RunnableLambda(retrieve)) - graph.add_node("generate", RunnableLambda(generate)) + graph.add_node("agent", query_or_respond) + graph.add_node("tools", tool_node) + graph.add_node("generate", generate) + + graph.set_entry_point("agent") + + graph.add_conditional_edges( + "agent", + tools_condition, + { + "tools": "tools", + END: END, + }, + ) - graph.set_entry_point("retrieve") - graph.add_edge("retrieve", "generate") + graph.add_edge("tools", "generate") graph.add_edge("generate", END) return graph.compile() - def invoke(self, input: dict) -> State: - state: State = { - "messages": self.memory.messages, - "question": input["input"], + def invoke(self, input) -> State: + self.memory.add_user_message(HumanMessage(content=input)) + initial_state: State = { + "messages": self.memory.get_messages(), "context": [], - "output": "", } - return self.chain.invoke(state) + return self.chain.invoke(initial_state) diff --git a/src/cp_genie/domain/rag/normal.py b/src/cp_genie/domain/rag/normal.py index bdc3b6c..938c3a5 100644 --- a/src/cp_genie/domain/rag/normal.py +++ b/src/cp_genie/domain/rag/normal.py @@ -14,6 +14,18 @@ def __init__(self, llm, retriever, memory): self.chain = self._build_graph() def _build_graph(self) -> StateGraph: + + prompt = ChatPromptTemplate.from_messages( + [ + ("system", "Answer as concisely as possible."), + ( + "human", + "chat history: {messages}\nretrieved context: {context}\n", + ), + ] + ) + combine_chain = create_stuff_documents_chain(self.llm, prompt) + def retrieve(state) -> State: last_message = self.memory.get_lastest_message().content docs = self.retriever.invoke(last_message) @@ -23,24 +35,13 @@ def generate(state) -> State: result = combine_chain.invoke( { "messages": state["messages"], - "context": state["context"], + "context": state.get("context", []), } ) self.memory.add_ai_message(result) - updated_messages = self.memory.get_messages() return {**state, "messages": updated_messages} - prompt = ChatPromptTemplate.from_messages( - [ - ("system", "Answer as concisely as possible."), - ( - "human", - "chat history: {messages}\nretrieved context: {context}\n", - ), - ] - ) - combine_chain = create_stuff_documents_chain(self.llm, prompt) graph = StateGraph(State) graph.add_node("retrieve", RunnableLambda(retrieve)) graph.add_node("generate", RunnableLambda(generate)) @@ -53,8 +54,8 @@ def generate(state) -> State: def invoke(self, input) -> State: self.memory.add_user_message(HumanMessage(content=input)) - state: State = { + initial_state: State = { "messages": self.memory.get_messages(), "context": [], } - return self.chain.invoke(state) + return self.chain.invoke(initial_state) From 9b1181c74cf4a70c1677961c5153cf96a9c2d738 Mon Sep 17 00:00:00 2001 From: Thiraput01 Date: Mon, 28 Apr 2025 21:05:51 +0700 Subject: [PATCH 2/2] fix: workflows --- .github/workflows/check.yml | 13 +++-- poetry.lock | 98 +++++++++++++++++++++++++++++++++---- pyproject.toml | 27 ++++++++++ 3 files changed, 125 insertions(+), 13 deletions(-) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index afd827c..b26a8ae 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -13,7 +13,12 @@ jobs: - uses: actions/checkout@v4 - uses: ./.github/actions/setup - run: poetry install --with checks - - run: poetry run invoke checks.format - - run: poetry run invoke checks.code - - run: poetry run invoke checks.type - - run: poetry run invoke checks.security + + - name: Run Ruff + run: poetry run ruff check . + + - name: Run Mypy + run: poetry run mypy src/ + + - name: Run Tests + run: poetry run pytest diff --git a/poetry.lock b/poetry.lock index b720a36..860bdfc 100644 --- a/poetry.lock +++ b/poetry.lock @@ -647,12 +647,12 @@ version = "0.4.6" description = "Cross-platform colored terminal text." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" -groups = ["main", "dev"] -markers = "platform_system == \"Windows\" or sys_platform == \"win32\"" +groups = ["main", "checks", "dev"] files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +markers = {main = "platform_system == \"Windows\" or sys_platform == \"win32\"", checks = "sys_platform == \"win32\"", dev = "platform_system == \"Windows\" or sys_platform == \"win32\""} [[package]] name = "coloredlogs" @@ -1607,7 +1607,7 @@ version = "2.1.0" description = "brain-dead simple config-ini parsing" optional = false python-versions = ">=3.8" -groups = ["dev"] +groups = ["checks", "dev"] files = [ {file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"}, {file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"}, @@ -2868,13 +2868,66 @@ files = [ [package.dependencies] dill = ">=0.3.7" +[[package]] +name = "mypy" +version = "1.15.0" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.9" +groups = ["checks"] +files = [ + {file = "mypy-1.15.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:979e4e1a006511dacf628e36fadfecbcc0160a8af6ca7dad2f5025529e082c13"}, + {file = "mypy-1.15.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c4bb0e1bd29f7d34efcccd71cf733580191e9a264a2202b0239da95984c5b559"}, + {file = "mypy-1.15.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:be68172e9fd9ad8fb876c6389f16d1c1b5f100ffa779f77b1fb2176fcc9ab95b"}, + {file = "mypy-1.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c7be1e46525adfa0d97681432ee9fcd61a3964c2446795714699a998d193f1a3"}, + {file = "mypy-1.15.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2e2c2e6d3593f6451b18588848e66260ff62ccca522dd231cd4dd59b0160668b"}, + {file = "mypy-1.15.0-cp310-cp310-win_amd64.whl", hash = "sha256:6983aae8b2f653e098edb77f893f7b6aca69f6cffb19b2cc7443f23cce5f4828"}, + {file = "mypy-1.15.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2922d42e16d6de288022e5ca321cd0618b238cfc5570e0263e5ba0a77dbef56f"}, + {file = "mypy-1.15.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2ee2d57e01a7c35de00f4634ba1bbf015185b219e4dc5909e281016df43f5ee5"}, + {file = "mypy-1.15.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:973500e0774b85d9689715feeffcc980193086551110fd678ebe1f4342fb7c5e"}, + {file = "mypy-1.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5a95fb17c13e29d2d5195869262f8125dfdb5c134dc8d9a9d0aecf7525b10c2c"}, + {file = "mypy-1.15.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1905f494bfd7d85a23a88c5d97840888a7bd516545fc5aaedff0267e0bb54e2f"}, + {file = "mypy-1.15.0-cp311-cp311-win_amd64.whl", hash = "sha256:c9817fa23833ff189db061e6d2eff49b2f3b6ed9856b4a0a73046e41932d744f"}, + {file = "mypy-1.15.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:aea39e0583d05124836ea645f412e88a5c7d0fd77a6d694b60d9b6b2d9f184fd"}, + {file = "mypy-1.15.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2f2147ab812b75e5b5499b01ade1f4a81489a147c01585cda36019102538615f"}, + {file = "mypy-1.15.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ce436f4c6d218a070048ed6a44c0bbb10cd2cc5e272b29e7845f6a2f57ee4464"}, + {file = "mypy-1.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8023ff13985661b50a5928fc7a5ca15f3d1affb41e5f0a9952cb68ef090b31ee"}, + {file = "mypy-1.15.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1124a18bc11a6a62887e3e137f37f53fbae476dc36c185d549d4f837a2a6a14e"}, + {file = "mypy-1.15.0-cp312-cp312-win_amd64.whl", hash = "sha256:171a9ca9a40cd1843abeca0e405bc1940cd9b305eaeea2dda769ba096932bb22"}, + {file = "mypy-1.15.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:93faf3fdb04768d44bf28693293f3904bbb555d076b781ad2530214ee53e3445"}, + {file = "mypy-1.15.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:811aeccadfb730024c5d3e326b2fbe9249bb7413553f15499a4050f7c30e801d"}, + {file = "mypy-1.15.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:98b7b9b9aedb65fe628c62a6dc57f6d5088ef2dfca37903a7d9ee374d03acca5"}, + {file = "mypy-1.15.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c43a7682e24b4f576d93072216bf56eeff70d9140241f9edec0c104d0c515036"}, + {file = "mypy-1.15.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:baefc32840a9f00babd83251560e0ae1573e2f9d1b067719479bfb0e987c6357"}, + {file = "mypy-1.15.0-cp313-cp313-win_amd64.whl", hash = "sha256:b9378e2c00146c44793c98b8d5a61039a048e31f429fb0eb546d93f4b000bedf"}, + {file = "mypy-1.15.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e601a7fa172c2131bff456bb3ee08a88360760d0d2f8cbd7a75a65497e2df078"}, + {file = "mypy-1.15.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:712e962a6357634fef20412699a3655c610110e01cdaa6180acec7fc9f8513ba"}, + {file = "mypy-1.15.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f95579473af29ab73a10bada2f9722856792a36ec5af5399b653aa28360290a5"}, + {file = "mypy-1.15.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8f8722560a14cde92fdb1e31597760dc35f9f5524cce17836c0d22841830fd5b"}, + {file = "mypy-1.15.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1fbb8da62dc352133d7d7ca90ed2fb0e9d42bb1a32724c287d3c76c58cbaa9c2"}, + {file = "mypy-1.15.0-cp39-cp39-win_amd64.whl", hash = "sha256:d10d994b41fb3497719bbf866f227b3489048ea4bbbb5015357db306249f7980"}, + {file = "mypy-1.15.0-py3-none-any.whl", hash = "sha256:5469affef548bd1895d86d3bf10ce2b44e33d86923c29e4d675b3e323437ea3e"}, + {file = "mypy-1.15.0.tar.gz", hash = "sha256:404534629d51d3efea5c800ee7c42b72a6554d6c400e6a79eafe15d11341fd43"}, +] + +[package.dependencies] +mypy_extensions = ">=1.0.0" +typing_extensions = ">=4.6.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +faster-cache = ["orjson"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + [[package]] name = "mypy-extensions" version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." optional = false python-versions = ">=3.5" -groups = ["main", "dev"] +groups = ["main", "checks", "dev"] files = [ {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, @@ -3458,7 +3511,7 @@ version = "24.2" description = "Core utilities for Python packages" optional = false python-versions = ">=3.8" -groups = ["main", "dev"] +groups = ["main", "checks", "dev"] files = [ {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, @@ -3729,7 +3782,7 @@ version = "1.5.0" description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.8" -groups = ["dev"] +groups = ["checks", "dev"] files = [ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, @@ -4350,7 +4403,7 @@ version = "8.3.5" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.8" -groups = ["dev"] +groups = ["checks", "dev"] files = [ {file = "pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820"}, {file = "pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845"}, @@ -5057,6 +5110,33 @@ files = [ [package.dependencies] pyasn1 = ">=0.1.3" +[[package]] +name = "ruff" +version = "0.4.10" +description = "An extremely fast Python linter and code formatter, written in Rust." +optional = false +python-versions = ">=3.7" +groups = ["checks"] +files = [ + {file = "ruff-0.4.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:5c2c4d0859305ac5a16310eec40e4e9a9dec5dcdfbe92697acd99624e8638dac"}, + {file = "ruff-0.4.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:a79489607d1495685cdd911a323a35871abfb7a95d4f98fc6f85e799227ac46e"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1dd1681dfa90a41b8376a61af05cc4dc5ff32c8f14f5fe20dba9ff5deb80cd6"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c75c53bb79d71310dc79fb69eb4902fba804a81f374bc86a9b117a8d077a1784"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18238c80ee3d9100d3535d8eb15a59c4a0753b45cc55f8bf38f38d6a597b9739"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d8f71885bce242da344989cae08e263de29752f094233f932d4f5cfb4ef36a81"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:330421543bd3222cdfec481e8ff3460e8702ed1e58b494cf9d9e4bf90db52b9d"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e9b6fb3a37b772628415b00c4fc892f97954275394ed611056a4b8a2631365e"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f54c481b39a762d48f64d97351048e842861c6662d63ec599f67d515cb417f6"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:67fe086b433b965c22de0b4259ddfe6fa541c95bf418499bedb9ad5fb8d1c631"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:acfaaab59543382085f9eb51f8e87bac26bf96b164839955f244d07125a982ef"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3cea07079962b2941244191569cf3a05541477286f5cafea638cd3aa94b56815"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:338a64ef0748f8c3a80d7f05785930f7965d71ca260904a9321d13be24b79695"}, + {file = "ruff-0.4.10-py3-none-win32.whl", hash = "sha256:ffe3cd2f89cb54561c62e5fa20e8f182c0a444934bf430515a4b422f1ab7b7ca"}, + {file = "ruff-0.4.10-py3-none-win_amd64.whl", hash = "sha256:67f67cef43c55ffc8cc59e8e0b97e9e60b4837c8f21e8ab5ffd5d66e196e25f7"}, + {file = "ruff-0.4.10-py3-none-win_arm64.whl", hash = "sha256:dd1fcee327c20addac7916ca4e2653fbbf2e8388d8a6477ce5b4e986b68ae6c0"}, + {file = "ruff-0.4.10.tar.gz", hash = "sha256:3aa4f2bc388a30d346c56524f7cacca85945ba124945fe489952aadb6b5cd804"}, +] + [[package]] name = "safetensors" version = "0.5.3" @@ -5814,7 +5894,7 @@ version = "4.13.2" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" -groups = ["main", "dev"] +groups = ["main", "checks", "dev"] files = [ {file = "typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c"}, {file = "typing_extensions-4.13.2.tar.gz", hash = "sha256:e6c81219bd689f51865d9e372991c540bda33a0379d5573cddb9a3a23f7caaef"}, @@ -6347,4 +6427,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.1" python-versions = ">=3.11,<3.13" -content-hash = "4c5a767e35e28560bd47e580962c8cf6c7a14dc363292748c4ed97a9abcdc445" +content-hash = "6367991156a879517f8fd01d9b5a64d51227f4da1c727e6e3c1ff3e8db5ec704" diff --git a/pyproject.toml b/pyproject.toml index 8b977b0..005469a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,33 @@ packages = [ { include = "cp_genie", from = "src" } ] +# Ruff +[tool.ruff] +fix = true +indent-width = 4 +line-length = 100 +target-version = "py312" + +[tool.ruff.format] +docstring-code-format = true + +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.ruff.lint.per-file-ignores] +"tests/*.py" = ["D100", "D103"] + + +# Checks +[tool.poetry.group.checks] +optional = true + +[tool.poetry.group.checks.dependencies] +ruff = "^0.4.0" +mypy = "^1.8.0" +pytest = "^8.1.1" + + [tool.poetry.scripts] start = "run:main"