diff --git a/.idea/vcs.xml b/.idea/vcs.xml index 35eb1ddf..72cefee0 100644 --- a/.idea/vcs.xml +++ b/.idea/vcs.xml @@ -2,5 +2,8 @@ + + + \ No newline at end of file diff --git a/example/cgi/timeout.cgi b/example/cgi/timeout.cgi new file mode 100755 index 00000000..7017f8ae --- /dev/null +++ b/example/cgi/timeout.cgi @@ -0,0 +1,15 @@ +#!/bin/bash + +# HTTP Header +echo "Content-Type: text/html" +sleep 10 +echo "" + +# HTML Content +echo "" +echo "CGI Test" +echo "" +echo "

Hello, CGI!

" +echo "

This page was generated by a Shell script.

" +echo "" +echo "" \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 91e05b7d..ea8788b7 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -18,6 +18,8 @@ add_library(webserv_lib STATIC lib/utils/logger.hpp lib/utils/string.cpp lib/utils/string.hpp + lib/utils/time.cpp + lib/utils/time.hpp lib/utils/types/option.hpp lib/utils/types/result.hpp lib/utils/types/unit.hpp diff --git a/src/lib/core/action/run_cgi_action.cpp b/src/lib/core/action/run_cgi_action.cpp index cb675768..8c8dbc2a 100644 --- a/src/lib/core/action/run_cgi_action.cpp +++ b/src/lib/core/action/run_cgi_action.cpp @@ -8,6 +8,7 @@ #include "core/handler/write_cgi_request_handler.hpp" #include "../../cgi/meta_variable.hpp" #include "utils/fd.hpp" +#include "utils/time.hpp" #include "utils/logger.hpp" void RunCgiAction::execute(ActionContext &ctx) { @@ -134,5 +135,6 @@ void RunCgiAction::parentRoutine(const ActionContext &ctx, const int socketFd, c ctx.getState().getEventNotifier().unregisterEvent(Event(clientFd_, Event::kWrite)); ctx.getState().getEventHandlerRepository().remove(clientFd_, Event::kRead); ctx.getState().getEventHandlerRepository().remove(clientFd_, Event::kWrite); - ctx.getState().getCgiProcessRepository().set(childPid, {clientFd_, socketFd}); + CgiProcessRepository::Data data = {clientFd_, socketFd, utils::Time::getCurrentTime()}; + ctx.getState().getCgiProcessRepository().set(childPid, data); } diff --git a/src/lib/core/handler/read_request_handler.cpp b/src/lib/core/handler/read_request_handler.cpp index c0d7e673..c635870d 100644 --- a/src/lib/core/handler/read_request_handler.cpp +++ b/src/lib/core/handler/read_request_handler.cpp @@ -17,6 +17,9 @@ IEventHandler::InvokeResult ReadRequestHandler::invoke(const Context &ctx) { LOG_DEBUG("start ReadRequestHandler"); const auto conn = ctx.getConnection().unwrap(); + // アクティビティを更新 + ctx.getConnection().unwrap().get().updateActivity(); + ReadBuffer &readBuf = conn.get().getReadBuffer(); const Result, error::AppError> result = reqReader_.readRequest(readBuf); if (result.isErr()) { diff --git a/src/lib/core/server.cpp b/src/lib/core/server.cpp index c2bab85c..594d0eaa 100644 --- a/src/lib/core/server.cpp +++ b/src/lib/core/server.cpp @@ -6,7 +6,10 @@ #include "http/response/response_builder.hpp" #include "transport/listener.hpp" #include "utils/logger.hpp" +#include "utils/time.hpp" #include +#include +#include Server::Server(const config::Config &config) : config_(config) { const config::ServerContextList &servers = config_.getServers(); @@ -74,8 +77,12 @@ void Server::start() { if (!result.empty()) LOG_DEBUG("child process reaped"); for (std::vector::const_iterator it = result.begin(); it != result.end(); ++it) { - if (it->status == 0) continue; const Option data = state_.getCgiProcessRepository().get(it->pid); + state_.getCgiProcessRepository().remove(it->pid); + + if (it->status == 0) continue; + + // CGI スクリプトがエラーで終わった場合 if (data.isNone()) { LOG_WARNF( "reaped child process %d with non-zero status %d, but no client fd found", @@ -109,6 +116,9 @@ void Server::start() { invokeHandlers(ctx); } + + // タイムアウトチェック + removeTimeoutHandlers(); } } @@ -210,3 +220,74 @@ void Server::executeActions(ActionContext &actionCtx, std::vector act delete action; } } + +void Server::removeTimeoutHandlers() { + removeTimeoutRequestHandlers(); + removeTimeoutCgiProcesses(); +} + +void Server::removeTimeoutRequestHandlers() { + const std::time_t currentTime = utils::Time::getCurrentTime(); + + const std::vector timedOutFds = + state_.getConnectionRepository().getTimedOutConnectionFds(currentTime, REQUEST_TIMEOUT_SECONDS, listenerFds_); + + for (std::vector::const_iterator it = timedOutFds.begin(); it != timedOutFds.end(); ++it) { + const int fd = *it; + + // Read ハンドラーが登録されているかチェック(リクエスト待ち状態) + const Option > readHandler = state_.getEventHandlerRepository().get(fd, Event::kRead); + if (readHandler.isNone()) { + continue; + } + + LOG_INFOF("Request timeout for fd %d", fd); + + // タイムアウトレスポンスを設定 + http::ResponseBuilder builder; + http::Response response = builder.status(http::kStatusRequestTimeout).build(); + + // Read ハンドラーを削除し、Write ハンドラーを設定 + state_.getEventNotifier().unregisterEvent(Event(fd, Event::kRead)); + state_.getEventHandlerRepository().remove(fd, Event::kRead); + state_.getEventNotifier().registerEvent(Event(fd, Event::kWrite)); + state_.getEventHandlerRepository().set(fd, Event::kWrite, new WriteResponseHandler(response)); + } +} + +void Server::removeTimeoutCgiProcesses() { + const std::time_t currentTime = utils::Time::getCurrentTime(); + const std::vector > timedOutProcesses = + state_.getCgiProcessRepository().getTimedOutProcesses(currentTime, CGI_TIMEOUT_SECONDS); + + for (std::vector >::const_iterator it = timedOutProcesses.begin(); + it != timedOutProcesses.end(); + ++it) { + const pid_t pid = it->first; + const CgiProcessRepository::Data &data = it->second; + + LOG_INFOF("CGI timeout for pid %d", pid); + + // CGI プロセスを強制終了 + if (kill(pid, SIGTERM) == -1) { + LOG_WARNF("Failed to terminate CGI process %d: %s", pid, std::strerror(errno)); + } + + // プロセスソケットをクリーンアップ + const int processSocketFd = data.processSocketFd; + const int clientFd = data.clientFd; + + state_.getEventNotifier().unregisterEvent(Event(processSocketFd, Event::kRead)); + state_.getEventNotifier().unregisterEvent(Event(processSocketFd, Event::kWrite)); + state_.getEventHandlerRepository().remove(processSocketFd, Event::kRead); + state_.getEventHandlerRepository().remove(processSocketFd, Event::kWrite); + state_.getConnectionRepository().remove(processSocketFd); + state_.getCgiProcessRepository().remove(pid); + + // Gateway Timeout レスポンスを返す + http::ResponseBuilder builder; + http::Response response = builder.status(http::kStatusGatewayTimeout).build(); + state_.getEventNotifier().registerEvent(Event(clientFd, Event::kWrite)); + state_.getEventHandlerRepository().set(clientFd, Event::kWrite, new WriteResponseHandler(response)); + } +} diff --git a/src/lib/core/server.hpp b/src/lib/core/server.hpp index 9b85d5ab..c8bbb604 100644 --- a/src/lib/core/server.hpp +++ b/src/lib/core/server.hpp @@ -19,6 +19,9 @@ class Server { // VirtualServer は http::Router を持っていて、コピー不可なのでポインタで持つ typedef std::vector VirtualServerList; + static const int REQUEST_TIMEOUT_SECONDS = 5; + static const int CGI_TIMEOUT_SECONDS = 5; + config::Config config_; VirtualServerList virtualServers_; @@ -32,6 +35,9 @@ class Server { static void executeActions(ActionContext &actionCtx, std::vector actions); void invokeHandlers(const Context &ctx); void invokeSingleHandler(const Context &ctx, const Ref &handler, bool shouldCallHandler); + void removeTimeoutHandlers(); + void removeTimeoutRequestHandlers(); + void removeTimeoutCgiProcesses(); }; #endif diff --git a/src/lib/core/server_state.cpp b/src/lib/core/server_state.cpp index 9897b990..61234d56 100644 --- a/src/lib/core/server_state.cpp +++ b/src/lib/core/server_state.cpp @@ -1,6 +1,9 @@ #include "server_state.hpp" #include "utils/logger.hpp" #include "utils/ref.hpp" +#include "utils/time.hpp" +#include +#include ConnectionRepository::ConnectionRepository() {} @@ -45,6 +48,30 @@ void ConnectionRepository::remove(const int fd) { LOG_DEBUGF("connection removed from server"); } +std::vector ConnectionRepository::getTimedOutConnectionFds( + std::time_t currentTime, double timeoutSeconds, const std::set &excludeFds +) const { + std::vector timedOutFds; + + for (std::map::const_iterator it = connections_.begin(); it != connections_.end(); ++it) { + const int fd = it->first; + Connection *conn = it->second; + + // 除外リストに含まれるFDはスキップ + if (excludeFds.count(fd) > 0) { + continue; + } + + // タイムアウトチェック + const double elapsed = utils::Time::diffTimeSeconds(currentTime, conn->getLastActivityTime()); + if (elapsed > timeoutSeconds) { + timedOutFds.push_back(fd); + } + } + + return timedOutFds; +} + EventHandlerRepository::EventHandlerRepository() {} EventHandlerRepository::~EventHandlerRepository() { @@ -104,6 +131,23 @@ void CgiProcessRepository::remove(const pid_t pid) { pidToData_.erase(pid); } +std::vector > +CgiProcessRepository::getTimedOutProcesses(std::time_t currentTime, double timeoutSeconds) const { + std::vector > timedOutProcesses; + + for (std::map::const_iterator it = pidToData_.begin(); it != pidToData_.end(); ++it) { + const pid_t pid = it->first; + const Data &data = it->second; + + const double elapsed = utils::Time::diffTimeSeconds(currentTime, data.startTime); + if (elapsed > timeoutSeconds) { + timedOutProcesses.push_back(std::make_pair(pid, data)); + } + } + + return timedOutProcesses; +} + ServerState::ServerState() { // self-pipe の読み端を監視対象にする reaper_.attachToEventNotifier(&getEventNotifier()); diff --git a/src/lib/core/server_state.hpp b/src/lib/core/server_state.hpp index 61b35020..49573142 100644 --- a/src/lib/core/server_state.hpp +++ b/src/lib/core/server_state.hpp @@ -7,6 +7,9 @@ #include "transport/connection.hpp" #include "utils/types/option.hpp" #include +#include +#include +#include // TODO: 共通化するべき? @@ -23,6 +26,10 @@ class ConnectionRepository : public NonCopyable { void set(int fd, Connection *conn); void remove(int fd); + // タイムアウトしたコネクションのFDを取得 + std::vector + getTimedOutConnectionFds(std::time_t currentTime, double timeoutSeconds, const std::set &excludeFds) const; + private: std::map connections_; }; @@ -48,6 +55,7 @@ class CgiProcessRepository : public NonCopyable { struct Data { int clientFd; int processSocketFd; + std::time_t startTime; }; CgiProcessRepository() {} @@ -56,6 +64,9 @@ class CgiProcessRepository : public NonCopyable { void set(pid_t pid, Data data); void remove(pid_t pid); + // タイムアウトしたプロセスを取得 + std::vector > getTimedOutProcesses(std::time_t currentTime, double timeoutSeconds) const; + private: std::map pidToData_; }; diff --git a/src/lib/event/event_notifier.cpp b/src/lib/event/event_notifier.cpp index e4991c9a..5049f301 100644 --- a/src/lib/event/event_notifier.cpp +++ b/src/lib/event/event_notifier.cpp @@ -72,9 +72,9 @@ void EpollEventNotifier::unregisterEvent(const Event &event) { } } -EpollEventNotifier::WaitEventsResult EpollEventNotifier::waitEvents() { +EpollEventNotifier::WaitEventsResult EpollEventNotifier::waitEvents(int timeoutMs) { epoll_event evs[1024]; - const int numEvents = epoll_wait(epollFd_.get(), evs, 1024, -1); + const int numEvents = epoll_wait(epollFd_.get(), evs, 1024, timeoutMs); if (numEvents == -1) { LOG_ERRORF("epoll_wait failed: %s", std::strerror(errno)); return Err(error::kUnknown); @@ -160,7 +160,7 @@ void PollEventNotifier::unregisterEvent(const Event &event) { } } -IEventNotifier::WaitEventsResult PollEventNotifier::waitEvents() { +IEventNotifier::WaitEventsResult PollEventNotifier::waitEvents(int timeoutMs) { std::vector fds; for (EventMap::const_iterator it = registeredEvents_.begin(); it != registeredEvents_.end(); ++it) { pollfd pfd = {}; @@ -169,7 +169,7 @@ IEventNotifier::WaitEventsResult PollEventNotifier::waitEvents() { fds.push_back(pfd); } - const int result = poll(fds.data(), fds.size(), -1); + const int result = poll(fds.data(), fds.size(), timeoutMs); if (result == -1) { LOG_ERRORF("poll failed: %s", std::strerror(errno)); return Err(error::kUnknown); diff --git a/src/lib/event/event_notifier.hpp b/src/lib/event/event_notifier.hpp index 25802ff1..43927add 100644 --- a/src/lib/event/event_notifier.hpp +++ b/src/lib/event/event_notifier.hpp @@ -19,7 +19,7 @@ class IEventNotifier { virtual void unregisterEvent(const Event &event) = 0; typedef Result, error::AppError> WaitEventsResult; - virtual WaitEventsResult waitEvents() = 0; + virtual WaitEventsResult waitEvents(int timeoutMs = 1000) = 0; }; // epoll の抽象 @@ -29,7 +29,7 @@ class EpollEventNotifier : public IEventNotifier { void registerEvent(const Event &event); void unregisterEvent(const Event &event); - WaitEventsResult waitEvents(); + WaitEventsResult waitEvents(int timeoutMs = 1000); private: AutoFd epollFd_; @@ -46,7 +46,7 @@ class PollEventNotifier : public IEventNotifier { void registerEvent(const Event &event); void unregisterEvent(const Event &event); - WaitEventsResult waitEvents(); + WaitEventsResult waitEvents(int timeoutMs = 1000); private: typedef std::map EventMap; diff --git a/src/lib/transport/connection.cpp b/src/lib/transport/connection.cpp index ee166444..9e82cbc4 100644 --- a/src/lib/transport/connection.cpp +++ b/src/lib/transport/connection.cpp @@ -1,9 +1,10 @@ #include "connection.hpp" #include "utils/logger.hpp" +#include "utils/time.hpp" Connection::Connection(const int fd, const Address &localAddress, const Address &foreignAddress) : clientFd_(fd), localAddress_(localAddress), foreignAddress_(foreignAddress), fdReader_(clientFd_), - buffer_(fdReader_) {} + buffer_(fdReader_), lastActivityTime_(utils::Time::getCurrentTime()) {} Connection::~Connection() { LOG_DEBUG("Connection: destruct"); @@ -24,3 +25,11 @@ const Address &Connection::getForeignAddress() const { ReadBuffer &Connection::getReadBuffer() { return buffer_; } + +std::time_t Connection::getLastActivityTime() const { + return lastActivityTime_; +} + +void Connection::updateActivity() { + lastActivityTime_ = utils::Time::getCurrentTime(); +} diff --git a/src/lib/transport/connection.hpp b/src/lib/transport/connection.hpp index 363ced85..a0067928 100644 --- a/src/lib/transport/connection.hpp +++ b/src/lib/transport/connection.hpp @@ -5,6 +5,7 @@ #include "utils/auto_fd.hpp" #include "utils/io/read_buffer.hpp" #include "utils/io/reader.hpp" +#include // クライアントソケットの抽象 class Connection { @@ -16,6 +17,8 @@ class Connection { const Address &getLocalAddress() const; const Address &getForeignAddress() const; ReadBuffer &getReadBuffer(); + std::time_t getLastActivityTime() const; + void updateActivity(); private: AutoFd clientFd_; @@ -23,6 +26,7 @@ class Connection { Address foreignAddress_; io::FdReader fdReader_; // ReadBuffer に渡す IReader & の参照先として必要 ReadBuffer buffer_; + std::time_t lastActivityTime_; }; #endif diff --git a/src/lib/utils/time.cpp b/src/lib/utils/time.cpp new file mode 100644 index 00000000..72acd99e --- /dev/null +++ b/src/lib/utils/time.cpp @@ -0,0 +1,13 @@ +#include "time.hpp" + +namespace utils { + + std::time_t Time::getCurrentTime() { + return std::time(NULL); + } + + double Time::diffTimeSeconds(std::time_t end, std::time_t start) { + return std::difftime(end, start); + } + +} diff --git a/src/lib/utils/time.hpp b/src/lib/utils/time.hpp new file mode 100644 index 00000000..7fda74e0 --- /dev/null +++ b/src/lib/utils/time.hpp @@ -0,0 +1,16 @@ +#ifndef TIME_HPP +#define TIME_HPP + +#include + +namespace utils { + + class Time { + public: + static std::time_t getCurrentTime(); + static double diffTimeSeconds(std::time_t end, std::time_t start); + }; + +} + +#endif