diff --git a/.github/labeler.yml b/.github/labeler.yml new file mode 100644 index 00000000..f18582c1 --- /dev/null +++ b/.github/labeler.yml @@ -0,0 +1,14 @@ +labels: + code-change: + - changed-files: + - any-glob-to-any-file: + - "src/**" + - "include/**" + - "priv/**" + - "**/*.erl" + - "**/*.hrl" + - "**/*.app.src" + - "**/*.app" + documentation: + - changed-files: + - '*.md' diff --git a/guides/multi-app.md b/guides/multi-app.md index 89b9b410..b858b53d 100644 --- a/guides/multi-app.md +++ b/guides/multi-app.md @@ -32,3 +32,14 @@ There's currently two different options available and they works in the same way ]} ... ``` + +## Pragmatically starting other nova applications + +### Starting an application + +You can also start other nova applications pragmatically by calling `nova_sup:add_application/2` to add another nova application to your supervision tree. The routes will automatically be added to the routing-module. + + +## Stopping an application + +To stop a nova application you can call `nova_sup:remove_application/1` with the name of the application you want to stop. Use this with caution since calling this method all routes for all other applications will be removed and re-added in order to filter out the one removed. diff --git a/rebar.config b/rebar.config index e7a6c92a..90c308a2 100644 --- a/rebar.config +++ b/rebar.config @@ -14,7 +14,6 @@ {cowboy, "2.13.0"}, {erlydtl, "0.14.0"}, {jhn_stdlib, "5.4.0"}, - {routing_tree, "1.0.11"}, {thoas, "1.2.1"} ]}. diff --git a/src/nova_request.erl b/src/nova_request.erl new file mode 100644 index 00000000..f3fdd24b --- /dev/null +++ b/src/nova_request.erl @@ -0,0 +1,143 @@ +%%%------------------------------------------------------------------- +%%% @author Niclas Axelsson +%%% @copyright (C) 2024, Niclas Axelsson +%%% @doc +%%% +%%% @end +%%% Created : 22 Dec 2024 by Niclas Axelsson +%%%------------------------------------------------------------------- +-module(nova_request). + +-behaviour(gen_server). + +%% API +-export([ + start_link/0 + ]). + +%% gen_server callbacks +-export([ + init/1, + handle_call/3, + handle_cast/2, + handle_info/2, + terminate/2, + code_change/3 + ]). + +-define(SERVER, ?MODULE). + +-record(state, {}). + +%%%=================================================================== +%%% API +%%%=================================================================== + +%%-------------------------------------------------------------------- +%% @doc +%% Starts the server +%% @end +%%-------------------------------------------------------------------- +-spec start_link() -> {ok, Pid :: pid()} | + {error, Error :: {already_started, pid()}} | + {error, Error :: term()} | + ignore. +start_link() -> + gen_server:start_link({local, ?SERVER}, ?MODULE, [], []). + +%%%=================================================================== +%%% gen_server callbacks +%%%=================================================================== + +%%-------------------------------------------------------------------- +%% @private +%% @doc +%% Initializes the server +%% @end +%%-------------------------------------------------------------------- +-spec init(Args :: term()) -> {ok, State :: term()} | + {ok, State :: term(), Timeout :: timeout()} | + {ok, State :: term(), hibernate} | + {stop, Reason :: term()} | + ignore. +init([]) -> + process_flag(trap_exit, true), + {ok, #state{}}. + +%%-------------------------------------------------------------------- +%% @private +%% @doc +%% Handling call messages +%% @end +%%-------------------------------------------------------------------- +-spec handle_call(Request :: term(), From :: {pid(), term()}, State :: term()) -> + {reply, Reply :: term(), NewState :: term()} | + {reply, Reply :: term(), NewState :: term(), Timeout :: timeout()} | + {reply, Reply :: term(), NewState :: term(), hibernate} | + {noreply, NewState :: term()} | + {noreply, NewState :: term(), Timeout :: timeout()} | + {noreply, NewState :: term(), hibernate} | + {stop, Reason :: term(), Reply :: term(), NewState :: term()} | + {stop, Reason :: term(), NewState :: term()}. +handle_call(_Request, _From, State) -> + Reply = ok, + {reply, Reply, State}. + +%%-------------------------------------------------------------------- +%% @private +%% @doc +%% Handling cast messages +%% @end +%%-------------------------------------------------------------------- +-spec handle_cast(Request :: term(), State :: term()) -> + {noreply, NewState :: term()} | + {noreply, NewState :: term(), Timeout :: timeout()} | + {noreply, NewState :: term(), hibernate} | + {stop, Reason :: term(), NewState :: term()}. +handle_cast(_Request, State) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% @private +%% @doc +%% Handling all non call/cast messages +%% @end +%%-------------------------------------------------------------------- +-spec handle_info(Info :: timeout() | term(), State :: term()) -> + {noreply, NewState :: term()} | + {noreply, NewState :: term(), Timeout :: timeout()} | + {noreply, NewState :: term(), hibernate} | + {stop, Reason :: normal | term(), NewState :: term()}. +handle_info(_Info, State) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% @private +%% @doc +%% This function is called by a gen_server when it is about to +%% terminate. It should be the opposite of Module:init/1 and do any +%% necessary cleaning up. When it returns, the gen_server terminates +%% with Reason. The return value is ignored. +%% @end +%%-------------------------------------------------------------------- +-spec terminate(Reason :: normal | shutdown | {shutdown, term()} | term(), + State :: term()) -> any(). +terminate(_Reason, _State) -> + ok. + +%%-------------------------------------------------------------------- +%% @private +%% @doc +%% Convert process state when code is changed +%% @end +%%-------------------------------------------------------------------- +-spec code_change(OldVsn :: term() | {down, term()}, + State :: term(), + Extra :: term()) -> {ok, NewState :: term()} | + {error, Reason :: term()}. +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== diff --git a/src/nova_router.erl b/src/nova_router.erl index 414e8b76..176175d6 100644 --- a/src/nova_router.erl +++ b/src/nova_router.erl @@ -25,15 +25,16 @@ %% Expose the router-callback routes/1, - %% Modulates the routes-table - add_routes/2, - %% Fetch information about the routing table plugins/0, - compiled_apps/0 + compiled_apps/0, + + %% Modulates the routes-table + add_routes/1, + add_routes/2, + remove_application/1 ]). --include_lib("routing_tree/include/routing_tree.hrl"). -include_lib("kernel/include/logger.hrl"). -include("../include/nova_router.hrl"). -include("../include/nova.hrl"). @@ -56,15 +57,21 @@ compiled_apps() -> StorageBackend = application:get_env(nova, dispatch_backend, persistent_term), StorageBackend:get(?NOVA_APPS, []). + +%% TODO! We need to implement a way to get and remove plugins for a path plugins() -> StorageBackend = application:get_env(nova, dispatch_backend, persistent_term), StorageBackend:get(?NOVA_PLUGINS, []). --spec compile(Apps :: [atom() | {atom(), map()}]) -> host_tree(). +-spec compile(Apps :: [atom() | {atom(), map()}]) -> nova_routing_trie:trie(). compile(Apps) -> UseStrict = application:get_env(nova, use_strict_routing, false), - Dispatch = compile(Apps, routing_tree:new(#{use_strict => UseStrict, convert_to_binary => true}), #{}), StorageBackend = application:get_env(nova, dispatch_backend, persistent_term), + + StoredDispatch = StorageBackend:get(nova_dispatch, + nova_routing_trie:new(#{options => #{strict => UseStrict}})), + Dispatch = compile(Apps, StoredDispatch, #{}), + %% Write the updated dispatch to storage StorageBackend:put(nova_dispatch, Dispatch), Dispatch. @@ -74,7 +81,7 @@ compile(Apps) -> execute(Req = #{host := Host, path := Path, method := Method}, Env) -> StorageBackend = application:get_env(nova, dispatch_backend, persistent_term), Dispatch = StorageBackend:get(nova_dispatch), - case routing_tree:lookup(Host, Path, Method, Dispatch) of + case nova_routing_trie:find(Host, Path, Method, Dispatch) of {error, not_found} -> logger:debug("Path ~p not found for ~p in ~p", [Path, Method, Host]), render_status_page('_', 404, #{error => "Not found in path"}, Req, Env); @@ -121,7 +128,7 @@ execute(Req = #{host := Host, path := Path, method := Method}, Env) -> } }; Error -> - ?LOG_ERROR(#{reason => <<"Unexpected return from routing_tree:lookup/4">>, + ?LOG_ERROR(#{reason => <<"Unexpected return from nova_routing_trie:lookup/4">>, return_object => Error}), render_status_page(Host, 404, #{error => Error}, Req, Env) end. @@ -138,7 +145,25 @@ lookup_url(Host, Path, Method) -> lookup_url(Host, Path, Method, Dispatch). lookup_url(Host, Path, Method, Dispatch) -> - routing_tree:lookup(Host, Path, Method, Dispatch). + nova_routing_trie:lookup(Host, Path, Method, Dispatch). + + +%%-------------------------------------------------------------------- +%% @doc +%% Works the same way as add_routes/2 but with the exception that you +%% don't need to provide the routes explicitly. When using this it's +%% expected that there's a routing-module associated with the application. +%% Eg. for the application 'test' the corresponding router would then be +%% 'test_router'. Read more about routers in the official documentation. +%% @end +%%-------------------------------------------------------------------- +-spec add_routes(App :: atom()) -> ok. +add_routes(App) -> + Router = erlang:list_to_atom(io_lib:format("~s_router", [App])), + Env = nova:get_environment(), + %% Call the router + Routes = Router:routes(Env), + add_routes(App, Routes). %%-------------------------------------------------------------------- %% @doc @@ -177,6 +202,25 @@ add_routes(App, Routes) -> throw({error, {invalid_routes, App, Routes}}). +%%-------------------------------------------------------------------- +%% @doc +%% Remove all routes associated with the given application. +%% @end +%%-------------------------------------------------------------------- +-spec remove_application(Application :: atom()) -> ok. +remove_application(Application) when is_atom(Application) -> + Dispatch = persistent_term:get(nova_dispatch), + %% Remove all routes for this application + {ok, Dispatch0} = + nova_routing_trie:foldl(Dispatch, + fun(R) -> + [ X || X = {_Host, _Prefix, #nova_handler_value{app = App}} <- R, + App =/= Application ] + end), + persistent_term:put(nova_dispatch, Dispatch0), + ok. + + %%%%%%%%%%%%%%%%%%%%%%%% %% INTERNAL FUNCTIONS %% %%%%%%%%%%%%%%%%%%%%%%%% @@ -201,7 +245,7 @@ apply_callback(Module, Function, Args) -> [] end. --spec compile(Apps :: [atom() | {atom(), map()}], Dispatch :: host_tree(), Options :: map()) -> host_tree(). +-spec compile(Apps :: [atom() | {atom(), map()}], Dispatch :: nova_routing_trie:trie(), Options :: map()) -> nova_routing_trie:trie(). compile([], Dispatch, _Options) -> Dispatch; compile([{App, Options}|Tl], Dispatch, GlobalOptions) -> compile([App|Tl], Dispatch, maps:merge(Options, GlobalOptions)); @@ -409,7 +453,7 @@ render_status_page(Host, StatusCode, Data, Req, Env) -> StorageBackend = application:get_env(nova, dispatch_backend, persistent_term), Dispatch = StorageBackend:get(nova_dispatch), {Req0, Env0} = - case routing_tree:lookup(Host, StatusCode, '_', Dispatch) of + case nova_routing_trie:find(Host, StatusCode, '_', Dispatch) of {error, _} -> %% Render nova page if exists - We need to determine where to find this path? {Req, Env#{app => nova, @@ -433,7 +477,7 @@ render_status_page(Host, StatusCode, Data, Req, Env) -> insert(Host, Path, Combinator, Value, Tree) -> - try routing_tree:insert(Host, Path, Combinator, Value, Tree) of + try nova_routing_trie:insert(Host, Path, Combinator, Value, Tree) of Tree0 -> Tree0 catch throw:Exception -> @@ -499,6 +543,8 @@ routes(_) -> -compile(export_all). %% Export all functions for testing purpose -include_lib("eunit/include/eunit.hrl"). - +compile_empty_test() -> + Dispatch = compile([]), + ?assertEqual(nova_routing_trie:new(#{options => #{strict => false}}), Dispatch). -endif. diff --git a/src/nova_routing_trie.erl b/src/nova_routing_trie.erl new file mode 100644 index 00000000..97c92677 --- /dev/null +++ b/src/nova_routing_trie.erl @@ -0,0 +1,984 @@ +-module(nova_routing_trie). + +-export([ + new/0, + new/1, + + %% Transform / rebuild from routes + foldl/2, + + %% Insert + insert/4, + insert/5, + insert/6, + + %% Membership + member/2, + member/3, + member/4, + + %% Unified lookup + lookup/2, + lookup/3, + lookup/4, + + to_list/1, + from_list/1 + ]). + +%% Host-aware trie: +%% - 'trie()' is the host root +%% - each host key maps to a per-host trie_node() (routing tree) +-opaque trie() :: #{ + options := map(), %% global/root options + hosts := #{ host_key() => trie_node() } + }. + +%% Per-host routing tree node +%% - children: nested segments +%% - terminal: map MethodBin => Payload +-opaque trie_node() :: #{ + options := map(), %% options for this node + children := #{ child_key() => trie_node() }, + plugins := #{ + pre_request := [module()], + post_request := [module()] + }, + terminal := #{ method() => payload() } + }. + +-type child_key() :: binary() | {wild, binary()}. + +-type host_key() :: '_' | binary(). +-type host_in() :: '_' | binary() | list() | atom(). + +%% Stored method representation: uppercase binary, e.g. <<"GET">>, <<"ALL">> +-type method() :: binary(). + +%% External method input: atom (get), binary (<<"GET">>), or string +-type method_in() :: method() | atom() | list(). + +-type payload() :: term(). + +%% A "route" is the canonical unit we can rebuild the trie from. +%% +%% Supported input shapes: +%% {Path, Method, Payload} +%% {Path, Method, Payload, Opts} +%% {Host, Path, Method, Payload} +%% {Host, Path, Method, Payload, Opts} +%% +%% Where: +%% - Host is optional (defaults to '_') +%% - Opts is optional (defaults to #{}) and is merged into trie options for that insertion +-type route() :: + {iodata(), method_in(), payload()} | + {iodata(), method_in(), payload(), map()} | + {host_in(), iodata(), method_in(), payload()} | + {host_in(), iodata(), method_in(), payload(), map()}. + +-type conflict() :: #{ + reason := atom(), + at := [child_key()], + existing := child_key() | undefined, + incoming := child_key() | undefined, + conflicts_with := binary(), + incoming_path := binary(), + method := method(), + existing_methods:= [method()] + }. + +-export_type([trie/0, trie_node/0, method/0]). + +%%==================================================================== +%% API +%%==================================================================== + +%%-------------------------------------------------------------------- +%% Constructors +%%-------------------------------------------------------------------- + +-spec new() -> trie(). +new() -> + new(#{}). + +-spec new(map()) -> trie(). +new(Opts) when is_map(Opts) -> + #{options => Opts, hosts => #{}}. + +%% Internal node constructor (per-host trees) +-spec new_node() -> trie_node(). +new_node() -> + new_node(#{}). + +-spec new_node(map()) -> trie_node(). +new_node(Opts) when is_map(Opts) -> + #{children => #{}, terminal => #{}, options => Opts, plugins => #{pre_request => [], + post_request => []}}. + +%%-------------------------------------------------------------------- +%% Host helpers +%%-------------------------------------------------------------------- + +-spec norm_host(host_in()) -> host_key(). +norm_host('_') -> + '_'; +norm_host(Host) when is_binary(Host) -> + Host; +norm_host(Host) when is_list(Host) -> + list_to_binary(Host); +norm_host(Host) when is_atom(Host) -> + list_to_binary(atom_to_list(Host)). + +%%-------------------------------------------------------------------- +%% Insert +%% +%% Canonical full form: +%% insert(Host, Path, Method, Payload, Trie, Opts) +%% +%% Convenience: +%% insert(Path, Method, Payload, Trie). % Host='_', Opts=#{} +%% insert(Host, Path, Method, Payload, Trie). % Opts=#{} +%%-------------------------------------------------------------------- + +-spec insert(iodata(), method_in(), payload(), trie()) -> + {ok, trie()} | {error, conflict, conflict()}. +insert(Path, Method, Payload, Trie) -> + insert('_', Path, Method, Payload, Trie, #{}). + +-spec insert(host_in(), iodata(), method_in(), payload(), trie()) -> + {ok, trie()} | {error, conflict, conflict()}. +insert(HostIn, Path, Method, Payload, Trie) -> + insert(HostIn, Path, Method, Payload, Trie, #{}). + +-spec insert(host_in(), iodata(), method_in(), payload(), trie(), map()) -> + {ok, trie()} | {error, conflict, conflict()}. +insert(HostIn, Path, Method, Payload, Trie0, Opts) -> + Host = norm_host(HostIn), + RootOpts0 = maps:get(options, Trie0, #{}), + RootOpts1 = maps:merge(RootOpts0, Opts), + Trie1 = Trie0#{options => RootOpts1}, + Hosts = maps:get(hosts, Trie1, #{}), + {Trie2, HostTrie0} = + case maps:get(Host, Hosts, undefined) of + undefined -> + HostTrie = new_node(RootOpts1), + Hosts1 = Hosts#{Host => HostTrie}, + {Trie1#{hosts := Hosts1}, HostTrie}; + HostTrie -> + {Trie1, HostTrie} + end, + case insert_inner(Method, Path, Payload, HostTrie0, RootOpts1) of + {ok, HostTrie1} -> + Hosts2 = maps:get(hosts, Trie2, #{}), + Hosts3 = Hosts2#{Host => HostTrie1}, + {ok, Trie2#{hosts := Hosts3}}; + {error, conflict, Conf} -> + {error, conflict, Conf} + end. + +%% Internal: insert into a single host's routing trie +-spec insert_inner(method_in(), iodata(), payload(), trie_node(), map()) -> + {ok, trie_node()} | {error, conflict, conflict()}. +insert_inner(Method, Path, Payload, TrieNode, Opts) -> + Opts0 = maps:get(options, TrieNode, #{}), + Opts1 = maps:merge(Opts0, Opts), + Trie1 = TrieNode#{options := Opts1}, + + M = norm_method(Method), + Segs = segs_for_insert(Path), + Strict = maps:get(strict, Opts1, false), + + case Strict of + true -> + %% Strict: conflicts become errors + safe_insert_segs(M, Payload, Segs, Trie1, [], Segs); + false -> + %% Non-strict: run safe_insert_segs only for detection, + %% then always perform regular insert_segs. Conflicts -> warnings. + case safe_insert_segs(M, Payload, Segs, Trie1, [], Segs) of + {error, conflict, Conf} -> + warn_conflict(Conf), + {ok, insert_segs(M, Payload, Segs, Trie1)}; + {ok, _TmpNode} -> + {ok, insert_segs(M, Payload, Segs, Trie1)} + end + end. + +%%-------------------------------------------------------------------- +%% Warnings (non-strict conflicts) +%%-------------------------------------------------------------------- +warn_conflict(Conf) -> + Reason = maps:get(reason, Conf), + Method = maps:get(method, Conf), + IncomingPath = maps:get(incoming_path, Conf), + ConflictsWith = maps:get(conflicts_with, Conf), + io:format( + "routing_trie warning (~p): ~s ~s conflicts with ~s~n", + [Reason, render_method(Method), IncomingPath, ConflictsWith] + ). + +%%-------------------------------------------------------------------- +%% Membership (uses lookup) +%%-------------------------------------------------------------------- + +-spec member(iodata(), trie()) -> boolean(). +member(Path, Trie) -> + member(<<"ALL">>, Path, Trie). + +-spec member(method_in(), iodata(), trie()) -> boolean(). +member(Method0, Path, Trie) -> + member(Method0, '_', Path, Trie). + +-spec member(method_in(), host_in(), iodata(), trie()) -> boolean(). +member(Method0, HostIn, Path, Trie) -> + case lookup(Method0, HostIn, Path, Trie) of + {ok, _Node, _Payload, _Binds} -> true; + error -> false + end. + +%%-------------------------------------------------------------------- +%% Unified lookup +%% +%% - Performs host selection (with '_' fallback) +%% - Performs path + wildcard matching +%% - Chooses payload according to Method (method-specific or <<"ALL">>) +%% - Returns: +%% error +%% | {ok, Node, Payload, Bindings} +%%-------------------------------------------------------------------- + +-spec lookup(iodata(), trie()) -> + {ok, trie_node(), term(), #{binary() => binary()}} | error. +lookup(Path, Trie) -> + lookup(<<"ALL">>, '_', Path, Trie). + +-spec lookup(method_in(), iodata(), trie()) -> + {ok, trie_node(), term(), #{binary() => binary()}} | error. +lookup(Method0, Path, Trie) -> + lookup(Method0, '_', Path, Trie). + +-spec lookup(method_in(), host_in(), iodata(), trie()) -> + {ok, trie_node(), term(), #{binary() => binary()}} | error. +lookup(Method0, HostIn, Path, Trie = #{hosts := Hosts}) -> + Host = norm_host(HostIn), + case maps:get(Host, Hosts, undefined) of + undefined when Host =/= '_' -> + lookup(Method0, '_', Path, Trie); + undefined -> + error; + HostTrie -> + M = norm_method(Method0), + Segs = segs_for_match(Path), + case do_match(Segs, HostTrie, #{}) of + {ok, Node, Binds} -> + case method_payload(M, Node) of + {ok, Payload} -> {ok, Node, Payload, Binds}; + error -> error + end; + _ -> + error + end + end. + + +%%-------------------------------------------------------------------- +%% to_list: flattened host view (ignores payloads) +%%-------------------------------------------------------------------- + +-spec to_list(trie()) -> [binary()]. +to_list(Trie) -> + Hosts = maps:get(hosts, Trie, #{}), + maps:fold( + fun(_Host, HostTrie, Acc) -> + gather(HostTrie, [], Acc) + end, [], Hosts). + +%%-------------------------------------------------------------------- +%% from_list: build a routing trie from a list of routes +%%-------------------------------------------------------------------- + +-spec from_list([route()]) -> {ok, trie()} | {error, conflict, conflict()}. +from_list(Routes) when is_list(Routes) -> + from_list(Routes, #{}). + +%%-------------------------------------------------------------------- +%% foldl: rebuild the trie by transforming its extracted routes +%% +%% Function :: fun(([route()]) -> [route()]) +%% +%% 1) Extracts routes from the trie (including payloads) +%% 2) Calls Function(Routes) +%% 3) Rebuilds a new trie from the returned routes +%%-------------------------------------------------------------------- + +-spec foldl(trie(), fun(([route()]) -> [route()])) -> + {ok, trie()} | {error, conflict, conflict()}. +foldl(Trie0, Fun) when is_map(Trie0), is_function(Fun, 1) -> + Routes0 = routes(Trie0), + Routes1 = Fun(Routes0), + case is_list(Routes1) of + true -> + from_list(Routes1, maps:get(options, Trie0, #{})); + false -> + erlang:error({badreturn, {foldl, Fun, Routes1}}) + end. + +%%==================================================================== +%% Internal functions (methods, segments, conflicts, matching) +%%==================================================================== + +%%-------------------------------------------------------------------- +%% from_list helpers +%%-------------------------------------------------------------------- + +-spec from_list([route()], map()) -> {ok, trie()} | {error, conflict, conflict()}. +from_list(Routes, RootOpts) when is_list(Routes), is_map(RootOpts) -> + lists:foldl( + fun + (Route, {ok, TrieAcc}) -> + insert_route(Route, TrieAcc); + (_Route, Err={error, conflict, _Conf}) -> + Err + end, + {ok, new(RootOpts)}, + Routes + ). + +-spec insert_route(route(), trie()) -> {ok, trie()} | {error, conflict, conflict()}. +insert_route({Path, Method, Payload}, Trie) -> + insert(Path, Method, Payload, Trie); +insert_route({Path, Method, Payload, Opts}, Trie) when is_map(Opts) -> + insert('_', Path, Method, Payload, Trie, Opts); +insert_route({Host, Path, Method, Payload}, Trie) -> + insert(Host, Path, Method, Payload, Trie); +insert_route({Host, Path, Method, Payload, Opts}, Trie) when is_map(Opts) -> + insert(Host, Path, Method, Payload, Trie, Opts); +insert_route(Other, _Trie) -> + erlang:error({bad_route, Other}). + +%%-------------------------------------------------------------------- +%% Route extraction (used by foldl/2) +%%-------------------------------------------------------------------- + +-spec routes(trie()) -> [route()]. +routes(Trie) -> + Hosts = maps:get(hosts, Trie, #{}), + maps:fold( + fun(Host, HostTrie, Acc) -> + gather_routes(HostTrie, [], Host, Acc) + end, + [], + Hosts + ). + +gather_routes(Node, AccSegs, Host, Acc0) -> + Term = maps:get(terminal, Node, #{}), + Acc1 = + maps:fold( + fun(M, Payload, A) -> + Path = render_path(AccSegs), + [{Host, Path, M, Payload} | A] + end, + Acc0, + Term + ), + Cs = maps:get(children, Node, #{}), + maps:fold( + fun(K, Child, A) -> + gather_routes(Child, AccSegs ++ [K], Host, A) + end, + Acc1, + Cs + ). + +%% Methods + +-spec norm_method(method_in()) -> method(). +norm_method(M) when is_binary(M) -> + %% Assume already normalized (e.g. <<"GET">>, <<"POST">>, <<"ALL">>) + M; +norm_method(M) when is_atom(M) -> + list_to_binary(string:uppercase(atom_to_list(M))); +norm_method(M) when is_list(M) -> + list_to_binary(string:uppercase(M)). + +terminal_add(M, Payload, Node0=#{terminal := Term0}) -> + Node0#{terminal := Term0#{ M => Payload }}. + +%% Resolve payload for a given method at a terminal node +-spec method_payload(method(), trie_node()) -> {ok, term()} | error. +method_payload(M, #{terminal := Term}) -> + case M of + <<"ALL">> -> + case maps:find(<<"ALL">>, Term) of + {ok, Payload} -> + {ok, Payload}; + error -> + case maps:to_list(Term) of + [] -> error; + [{_, P} | _] -> {ok, P} + end + end; + _ -> + case maps:find(M, Term) of + {ok, Payload} -> + {ok, Payload}; + error -> + case maps:find(<<"ALL">>, Term) of + {ok, Payload} -> {ok, Payload}; + error -> error + end + end + end. + +%% method_member(M, Node) -> +%% case method_payload(M, Node) of +%% {ok, _} -> true; +%% error -> false +%% end. + +methods_list(#{terminal := Term}) -> + [K || {K, _} <- maps:to_list(Term)]. + +render_method(M) when is_binary(M) -> + M; +render_method(M) -> + norm_method(M). + +%% Segments and paths + +segs_for_insert(Path) when is_list(Path) -> + segs_for_insert(list_to_binary(Path)); +segs_for_insert(Path) when is_binary(Path) -> + Parts = [S || S <- binary:split(Path, <<"/">>, [global, trim]), S =/= <<>>], + [<<"/">> | [ to_key(S) || S <- Parts ]]. + +segs_for_match(Path) when is_list(Path) -> + segs_for_match(list_to_binary(Path)); +segs_for_match(Path) when is_binary(Path) -> + Parts = [S || S <- binary:split(Path, <<"/">>, [global, trim]), S =/= <<>>], + [<<"/">> | Parts]. + +to_key(<<":", Rest/binary>>) -> {wild, Rest}; +to_key(Bin) -> Bin. + +render_key({wild, Name}) -> <<":", Name/binary>>; +render_key(Bin) -> Bin. + +%% Older-Erlang-safe binary join +-spec join_with_sep([binary()], binary()) -> binary(). +join_with_sep([], _Sep) -> + <<>>; +join_with_sep([First | Rest], Sep) -> + iolist_to_binary( + lists:foldl( + fun(Bin, Acc) -> [Acc, Sep, Bin] end, + First, + Rest + )). + +render_path(Segs) -> + case Segs of + [] -> <<"/">>; + [<<"/">>] -> <<"/">>; + [<<"/">> | Rest] -> + RestBins = [render_key(S) || S <- Rest], + <<"/", (join_with_sep(RestBins, <<"/">>))/binary>>; + _ -> + join_with_sep([render_key(S) || S <- Segs], <<"/">>) + end. + +%% Insert without conflict checking (single-host trie) + +insert_segs(M, Payload, [], N0) -> + terminal_add(M, Payload, N0); +insert_segs(M, Payload, [K | Rest], N0) -> + Cs0 = maps:get(children, N0), + Child0 = maps:get(K, Cs0, new_node()), + Child1 = insert_segs(M, Payload, Rest, Child0), + N0#{children := maps:put(K, Child1, Cs0)}. + +%% Conflict helpers + +%% Find any wildcard child +find_wild_child(Cs) -> + maps:fold( + fun + ({wild, Var}, Child, none) -> {wild, Var, Child}; + (_K, _Child, Acc) -> Acc + end, none, Cs). + +%% Find static child that will be affected by new wildcard at same depth +find_static_overshadow_child(Cs, M) -> + maps:fold( + fun + ({wild, _}, _Child, Acc) -> + Acc; + (K, Child, none) -> + Ms = methods_list(Child), + case overshadow_methods(M, Ms) of + true -> {found, {K, Child, Ms}}; + false -> none + end; + (_K, _Child, Acc) -> + Acc + end, none, Cs). + +%% Does method M conflict (overlap) with existing methods? +overshadow_methods(M, ExistingMs) -> + case M of + <<"ALL">> -> + ExistingMs =/= []; + _ -> + lists:member(M, ExistingMs) orelse + lists:member(<<"ALL">>, ExistingMs) + end. + +%% Safe insert with conflict checking (single-host trie) +safe_insert_segs(M, Payload, [], N0, Prefix, Full) -> + Ms = methods_list(N0), + case lists:member(M, Ms) of + true -> + {error, conflict, #{ + reason => duplicate_pattern, + at => Prefix, + existing => undefined, + incoming => undefined, + conflicts_with => render_path(Prefix), + incoming_path => render_path(Full), + method => M, + existing_methods=> Ms + }}; + false -> + case lists:member(<<"ALL">>, Ms) of + true when M =/= <<"ALL">> -> + {error, conflict, #{ + reason => duplicate_due_to_all, + at => Prefix, + existing => undefined, + incoming => undefined, + conflicts_with => render_path(Prefix), + incoming_path => render_path(Full), + method => M, + existing_methods=> Ms + }}; + _ when M =:= <<"ALL">>, Ms =/= [] -> + {error, conflict, #{ + reason => duplicate_due_to_existing_methods, + at => Prefix, + existing => undefined, + incoming => undefined, + conflicts_with => render_path(Prefix), + incoming_path => render_path(Full), + method => M, + existing_methods=> Ms + }}; + _ -> + {ok, terminal_add(M, Payload, N0)} + end + end; + +%% Branch when we insert a wildcard segment +safe_insert_segs(M, Payload, [K = {wild, NewVar} | Rest], N0, Prefix, Full) -> + Cs0 = maps:get(children, N0), + + %% 1) Wildcard overshadowing existing static terminal at same depth + case Rest of + [] -> + case find_static_overshadow_child(Cs0, M) of + {found, {ExistingKey, _ChildN, ExistingMs}} -> + {error, conflict, #{ + reason => overshadowing_route, + at => Prefix, + existing => ExistingKey, + incoming => K, + conflicts_with => render_path(Prefix ++ [ExistingKey]), + incoming_path => render_path(Full), + method => M, + existing_methods=> ExistingMs + }}; + none -> + wildcard_name_insert(M, Payload, NewVar, Rest, + N0, Prefix, Full, Cs0) + end; + _ -> + wildcard_name_insert(M, Payload, NewVar, Rest, + N0, Prefix, Full, Cs0) + end; + +%% Branch when we insert a static segment +safe_insert_segs(M, Payload, [K | Rest], N0, Prefix, Full) -> + Cs0 = maps:get(children, N0), + + %% Static overshadowing existing wildcard terminal at same depth + case Rest of + [] -> + case find_wild_child(Cs0) of + {wild, ExistingVar, ChildN} -> + ExistingMs = methods_list(ChildN), + case overshadow_methods(M, ExistingMs) of + true -> + {error, conflict, #{ + reason => overshadowing_route, + at => Prefix, + existing => {wild, ExistingVar}, + incoming => K, + conflicts_with => render_path(Prefix ++ [{wild, ExistingVar}]), + incoming_path => render_path(Full), + method => M, + existing_methods=> ExistingMs + }}; + false -> + static_continue(M, Payload, K, Rest, + N0, Prefix, Full, Cs0) + end; + none -> + static_continue(M, Payload, K, Rest, + N0, Prefix, Full, Cs0) + end; + _ -> + static_continue(M, Payload, K, Rest, + N0, Prefix, Full, Cs0) + end. + +%% Helper: continue wildcard insertion after overshadow/name checks +wildcard_name_insert(M, Payload, NewVar, Rest, N0, Prefix, Full, Cs0) -> + ExistingWild = find_wild_child(Cs0), + case ExistingWild of + none -> + Child0 = new_node(), + case safe_insert_segs(M, Payload, Rest, Child0, + Prefix ++ [{wild, NewVar}], Full) of + {ok, Child1} -> + {ok, N0#{children := maps:put({wild, NewVar}, Child1, Cs0)}}; + {error, conflict, Info} -> + {error, conflict, Info} + end; + {wild, ExistingVar, ChildN} -> + if ExistingVar =:= NewVar -> + case safe_insert_segs(M, Payload, Rest, ChildN, + Prefix ++ [{wild, NewVar}], Full) of + {ok, Child1} -> + {ok, N0#{children := maps:put({wild, ExistingVar}, Child1, Cs0)}}; + {error, conflict, Info} -> + {error, conflict, Info} + end; + true -> + {error, conflict, #{ + reason => wildcard_name_conflict, + at => Prefix, + existing => {wild, ExistingVar}, + incoming => {wild, NewVar}, + conflicts_with => render_path(Prefix ++ [{wild, ExistingVar}]), + incoming_path => render_path(Prefix ++ [{wild, NewVar}] ++ Rest), + method => M, + existing_methods=> methods_list(ChildN) + }} + end + end. + +%% Helper: continue static insertion after overshadow check +static_continue(M, Payload, K, Rest, N0, Prefix, Full, Cs0) -> + Child0 = maps:get(K, Cs0, new_node()), + case safe_insert_segs(M, Payload, Rest, Child0, + Prefix ++ [K], Full) of + {ok, Child1} -> + {ok, N0#{children := maps:put(K, Child1, Cs0)}}; + {error, conflict, Info} -> + {error, conflict, Info} + end. + +%%-------------------------------------------------------------------- +%% Matching (runtime lookup) +%%-------------------------------------------------------------------- + +do_match([], N, Binds) -> + {ok, N, Binds}; +do_match([Seg | Rest], N, Binds0) -> + Cs = maps:get(children, N), + + %% 1) Exact first + Exact = case maps:find(Seg, Cs) of + {ok, C} -> + case do_match(Rest, C, Binds0) of + error -> error; + Ok -> Ok + end; + error -> error + end, + case Exact of + {ok, _, _} -> Exact; + error -> + %% 2) Wildcard fallback + case find_wild_child(Cs) of + none -> error; + {wild, VarName, C0} -> + do_match(Rest, C0, Binds0#{ VarName => Seg }) + end + end. + +%%-------------------------------------------------------------------- +%% to_list helpers +%%-------------------------------------------------------------------- + +gather(N, AccSegs, AccOut) -> + Ms = methods_list(N), + AccOut1 = + case Ms of + [] -> AccOut; + _ -> + Path = render_path(AccSegs), + MethodLines = + [<< (render_method(M))/binary, " ", Path/binary >> || M <- Ms], + MethodLines ++ AccOut + end, + Cs = maps:get(children, N), + maps:fold( + fun(K, Child, Out) -> + gather(Child, AccSegs ++ [K], Out) + end, AccOut1, Cs). + +%%==================================================================== +%% EUnit tests +%%==================================================================== +-ifdef(TEST). +-include_lib("eunit/include/eunit.hrl"). + +default_host_insert_and_lookup_test() -> + T0 = new(), + {ok, T1} = insert(<<"/users">>, get, payload1, T0), + + ?assert(member(get, <<"/users">>, T1)), + ?assertMatch({ok, _Node, payload1, #{}}, + lookup(get, <<"/users">>, T1)). + +binary_method_insert_and_lookup_test() -> + T0 = new(), + {ok, T1} = insert(<<"/users">>, <<"GET">>, payload1, T0), + + %% lookup using binary method + ?assertMatch({ok, _Node, payload1, #{}}, + lookup(<<"GET">>, <<"/users">>, T1)), + %% lookup using atom method + ?assertMatch({ok, _Node, payload1, #{}}, + lookup(get, <<"/users">>, T1)). + +wildcard_path_lookup_test() -> + T0 = new(), + {ok, T1} = insert('_', <<"/users/:id">>, get, payload1, T0, + #{strict => true}), + + ?assertMatch({ok, _Node, payload1, #{<<"id">> := <<"42">>}}, + lookup(get, <<"/users/42">>, T1)), + ?assertMatch({ok, _Node, payload1, #{<<"id">> := <<"abc">>}}, + lookup(get, <<"/users/abc">>, T1)). + +method_filtering_lookup_test() -> + T0 = new(), + {ok, T1} = insert(<<"/users">>, post, payload_post, T0), + + ?assertMatch({ok, _Node, payload_post, #{}}, + lookup(post, <<"/users">>, T1)), + ?assertEqual(error, + lookup(get, <<"/users">>, T1)). + +host_specific_only_test() -> + T0 = new(), + Host = <<"http://api.example.com">>, + + {ok, T1} = insert(Host, <<"/users">>, get, payload_host, T0), + + ?assertMatch({ok, _Node, payload_host, #{}}, + lookup(get, Host, <<"/users">>, T1)), + ?assertEqual(error, + lookup(get, + <<"http://other.example.com">>, + <<"/users">>, T1)). + +host_fallback_to_catchall_test() -> + T0 = new(), + %% insert only on '_' host + {ok, T1} = insert('_', <<"/users">>, get, payload_all, T0, #{}), + + %% should match when querying with another host due to fallback + ?assertMatch({ok, _Node, payload_all, #{}}, + lookup(get, + <<"http://api.example.com">>, + <<"/users">>, T1)). + +host_and_method_lookup_test() -> + T0 = new(), + Host = <<"http://api.example.com">>, + + {ok, T1} = insert(Host, <<"/users">>, post, payload_post, T0), + + ?assertMatch({ok, _Node, payload_post, #{}}, + lookup(post, Host, <<"/users">>, T1)), + ?assertEqual(error, + lookup(get, Host, <<"/users">>, T1)), + %% and also check that another host falls back to '_' only if '_' exists + ?assertEqual(error, + lookup(post, + <<"http://other.example.com">>, + <<"/users">>, T1)). + +strict_conflict_duplicate_pattern_test() -> + T0 = new(), + {ok, T1} = + insert('_', <<"/users/:id">>, get, payload1, T0, + #{strict => true}), + + {error, conflict, Conf} = + insert('_', <<"/users/:id">>, get, payload2, T1, + #{strict => true}), + + ?assertEqual(duplicate_pattern, maps:get(reason, Conf)), + ?assertEqual(<<"GET">>, maps:get(method, Conf)). + +overshadow_strict_conflict_static_then_wild_test() -> + T0 = new(), + {ok, T1} = insert('_', <<"/user/my_user">>, get, payload_static, T0, + #{strict => true}), + + {error, conflict, Conf} = + insert('_', <<"/user/:user_id">>, get, payload_wild, T1, + #{strict => true}), + + ?assertEqual(overshadowing_route, maps:get(reason, Conf)), + ?assertEqual(<<"GET">>, maps:get(method, Conf)), + ?assertEqual(<<"/user/my_user">>, maps:get(conflicts_with, Conf)), + ?assertEqual(<<"/user/:user_id">>, maps:get(incoming_path, Conf)). + +overshadow_strict_conflict_wild_then_static_test() -> + T0 = new(), + {ok, T1} = insert('_', <<"/user/:user_id">>, get, payload_wild, T0, + #{strict => true}), + + {error, conflict, Conf} = + insert('_', <<"/user/my_user">>, get, payload_static, T1, + #{strict => true}), + + ?assertEqual(overshadowing_route, maps:get(reason, Conf)), + ?assertEqual(<<"GET">>, maps:get(method, Conf)), + ?assertEqual(<<"/user/:user_id">>, maps:get(conflicts_with, Conf)), + ?assertEqual(<<"/user/my_user">>, maps:get(incoming_path, Conf)). + +overshadow_non_strict_warning_static_then_wild_test() -> + T0 = new(), + {ok, T1} = insert(<<"/user/my_user">>, get, payload_static, T0), + + %% This should only warn, not error + {ok, T2} = insert(<<"/user/:user_id">>, get, payload_wild, T1), + + %% /user/my_user should still match the static route + ?assertMatch({ok, _Node, payload_static, #{}}, + lookup(get, <<"/user/my_user">>, T2)), + %% and /user/other should match the wildcard route + ?assertMatch({ok, _Node, payload_wild, + #{<<"user_id">> := <<"other">>}}, + lookup(get, <<"/user/other">>, T2)). + +overshadow_non_strict_warning_wild_then_static_test() -> + T0 = new(), + {ok, T1} = insert(<<"/user/:user_id">>, get, payload_wild, T0), + + %% This should only warn, not error + {ok, T2} = insert(<<"/user/my_user">>, get, payload_static, T1), + + %% /user/my_user should match the static route + ?assertMatch({ok, _Node, payload_static, #{}}, + lookup(get, <<"/user/my_user">>, T2)), + %% and /user/other should match the wildcard route + ?assertMatch({ok, _Node, payload_wild, + #{<<"user_id">> := <<"other">>}}, + lookup(get, <<"/user/other">>, T2)). + +to_list_simple_test() -> + T0 = new(), + {ok, T1} = insert(<<"/users/:id">>, get, payload1, T0), + Lines = to_list(T1), + + %% We expect "GET /users/:id" in the list + ?assert(lists:member(<<"GET /users/:id">>, Lines)). + +lookup_returns_node_and_bindings_test() -> + T0 = new(), + {ok, T1} = insert(<<"localhost">>, <<"/user/:id">>, get, payload1, T0, + #{strict => true}), + {ok, Node, Payload, Binds} = lookup(get, <<"localhost">>, <<"/user/42">>, T1), + ?assert(is_map(Node)), + ?assertEqual(payload1, Payload), + ?assertMatch(#{<<"id">> := <<"42">>}, Binds). + +foldl_can_filter_routes_test() -> + T0 = new(), + {ok, T1} = insert(<<"/a">>, get, payload_a, T0), + {ok, T2} = insert(<<"/b">>, get, payload_b, T1), + + {ok, T3} = + foldl( + T2, + fun(Routes0) -> + [R || R = {_Host, Path, _M, _P} <- Routes0, + Path =/= <<"/b">>] + end + ), + + ?assertMatch({ok, _Node, payload_a, #{}}, + lookup(get, <<"/a">>, T3)), + ?assertEqual(error, + lookup(get, <<"/b">>, T3)). + + +foldl_can_rewrite_payloads_test() -> + T0 = new(), + {ok, T1} = insert(<<"/a">>, get, payload_a, T0), + {ok, T2} = insert(<<"/b">>, get, payload_b, T1), + + %% Byt payload för /a men låt /b vara oförändrad + {ok, T3} = + foldl( + T2, + fun(Routes0) -> + [case R of + {Host, <<"/a">>, <<"GET">>, payload_a} -> + {Host, <<"/a">>, <<"GET">>, payload_a_v2}; + _ -> + R + end || R <- Routes0] + end + ), + + ?assertMatch({ok, _NodeA, payload_a_v2, #{}}, + lookup(get, <<"/a">>, T3)), + ?assertMatch({ok, _NodeB, payload_b, #{}}, + lookup(get, <<"/b">>, T3)). + +foldl_can_rewrite_methods_test() -> + T0 = new(), + {ok, T1} = insert(<<"/a">>, get, payload_a, T0), + + %% Flytta routen från GET till POST (vi skickar 'post' som atom för att + %% samtidigt testa att from_list/insert normaliserar method korrekt) + {ok, T2} = + foldl( + T1, + fun(Routes0) -> + [case R of + {Host, <<"/a">>, <<"GET">>, Payload} -> + {Host, <<"/a">>, post, Payload}; + _ -> + R + end || R <- Routes0] + end + ), + + %% GET ska inte längre matcha + ?assertEqual(error, + lookup(get, <<"/a">>, T2)), + %% POST ska matcha och ge samma payload + ?assertMatch({ok, _Node, payload_a, #{}}, + lookup(post, <<"/a">>, T2)). + +-endif. diff --git a/src/nova_sup.erl b/src/nova_sup.erl index 980f6253..d78f141a 100644 --- a/src/nova_sup.erl +++ b/src/nova_sup.erl @@ -8,7 +8,12 @@ -behaviour(supervisor). %% API --export([start_link/0]). +-export([ + start_link/0, + add_application/2, + remove_application/1, + get_started_applications/0 + ]). %% Supervisor callbacks -export([init/1]). @@ -17,9 +22,20 @@ -include("../include/nova.hrl"). -define(SERVER, ?MODULE). --define(NOVA_LISTENER, nova_listener). + +-define(NOVA_LISTENER, fun(LApp, LPort) -> list_to_atom(atom_to_list(LApp) ++ integer_to_list(LPort)) end). -define(NOVA_STD_PORT, 8080). -define(NOVA_STD_SSL_PORT, 8443). +-define(NOVA_SUP_TABLE, nova_sup_table). +-define(COWBOY_LISTENERS, cowboy_listeners). + + +-record(nova_server, { + app :: atom(), + host :: inet:ip_address(), + port :: number(), + listener :: ranch:ref() + }). %%%=================================================================== @@ -36,6 +52,55 @@ start_link() -> supervisor:start_link({local, ?SERVER}, ?MODULE, []). +%%-------------------------------------------------------------------- +%% @doc +%% Add a Nova application. This can either be on the same cowboy server that +%% a previous application was started with, or a new one if the configuration +%% ie port is different. +%% +%% @end +%%-------------------------------------------------------------------- +-spec add_application(App :: atom(), Configuration :: map()) -> {ok, App :: atom(), + Host :: inet:ip_address(), Port :: number()} + | {error, Reason :: any()}. +add_application(App, Configuration) -> + setup_cowboy(App, Configuration). + +%%-------------------------------------------------------------------- +%% @doc +%% Get all started Nova applications. This will return a list of +%% #nova_server{} records that contains the application name, host, port +%% and listener reference. +%% +%% @end +%%-------------------------------------------------------------------- +-spec get_started_applications() -> [#{app => atom(), host => inet:ip_address(), port => number()}]. +get_started_applications() -> + %% Fetch all started applications from the ETS table + Apps = ets:tab2list(?NOVA_SUP_TABLE), + [ #{app => App, host => Host, port => Port} || + #nova_server{app = App, host = Host, port = Port} <- Apps ]. + +%%-------------------------------------------------------------------- +%% @doc +%% Remove a Nova application. This will stop the cowboy listener so request +%% to that application will not be handled anymore. +%% +%% @end +%%-------------------------------------------------------------------- +remove_application(App) -> + case ets:lookup(?NOVA_SUP_TABLE, App) of + [] -> + ?LOG_ERROR(#{msg => <<"Application not found">>, app => App}), + {error, not_found}; + [#nova_server{listener = Listener}] -> + ?LOG_NOTICE(#{msg => <<"Stopping cowboy listener">>, app => App, listener => Listener}), + cowboy:stop_listener(Listener), + ets:delete(?NOVA_SUP_TABLE, App), + %% Now we need to remove all routes associated with this listener + ok + end. + %%%=================================================================== %%% Supervisor callbacks %%%=================================================================== @@ -51,19 +116,20 @@ start_link() -> %% @end %%-------------------------------------------------------------------- init([]) -> + %% Initialize the ETS table for application state + ets:new(?NOVA_SUP_TABLE, [named_table, protected, set]), + %% This is a bit ugly, but we need to do this anyhow(?) SupFlags = #{strategy => one_for_one, intensity => 1, period => 5}, + %% Bootstrap the environment Environment = nova:get_environment(), - nova_pubsub:start(), ?LOG_NOTICE(#{msg => <<"Starting nova">>, environment => Environment}), - Configuration = application:get_env(nova, cowboy_configuration, #{}), - SessionManager = application:get_env(nova, session_manager, nova_session_ets), Children0 = [ @@ -77,7 +143,7 @@ init([]) -> false -> Children0 end, - setup_cowboy(Configuration), + setup_cowboy(), {ok, {SupFlags, Children}}. @@ -99,8 +165,17 @@ child(Id, Type, Mod) -> child(Id, Mod) -> child(Id, worker, Mod). -setup_cowboy(Configuration) -> - case start_cowboy(Configuration) of + +%%%------------------------------------------------------------------- +%%% Nova Cowboy setup +%%%------------------------------------------------------------------- +setup_cowboy() -> + CowboyConfiguration = application:get_env(nova, cowboy_configuration, #{}), + BootstrapApp = application:get_env(nova, bootstrap_application, undefined), + setup_cowboy(BootstrapApp, CowboyConfiguration). + +setup_cowboy(BootstrapApp, Configuration) -> + case start_cowboy(BootstrapApp, Configuration) of {ok, App, Host, Port} -> Host0 = inet:ntoa(Host), CowboyVersion = get_version(cowboy), @@ -114,97 +189,110 @@ setup_cowboy(Configuration) -> ?LOG_ERROR(#{msg => <<"Cowboy could not start">>, reason => Error}) end. --spec start_cowboy(Configuration :: map()) -> + +-spec start_cowboy(BootstrapApp :: atom(), Configuration :: map()) -> {ok, BootstrapApp :: atom(), Host :: string() | {integer(), integer(), integer(), integer()}, Port :: integer()} | {error, Reason :: any()}. -start_cowboy(Configuration) -> - Middlewares = [ - nova_router, %% Lookup routes - nova_plugin_handler, %% Handle pre-request plugins - nova_security_handler, %% Handle security - nova_handler, %% Controller - nova_plugin_handler %% Handle post-request plugins - ], - StreamH = [nova_stream_h, - cowboy_compress_h, - cowboy_stream_h], - StreamHandlers = maps:get(stream_handlers, Configuration, StreamH), - MiddlewareHandlers = maps:get(middleware_handlers, Configuration, Middlewares), - Options = maps:get(options, Configuration, #{compress => true}), - - %% Build the options map - CowboyOptions1 = Options#{middlewares => MiddlewareHandlers, - stream_handlers => StreamHandlers}, - - BootstrapApp = application:get_env(nova, bootstrap_application, undefined), - - %% Compile the routes - Dispatch = - case BootstrapApp of - undefined -> - ?LOG_ERROR(#{msg => <<"You need to define bootstrap_application option in configuration">>}), - throw({error, no_nova_app_defined}); - App -> - ExtraApps = application:get_env(App, nova_apps, []), - nova_router:compile([nova|[App|ExtraApps]]) - end, - - CowboyOptions2 = - case application:get_env(nova, use_persistent_term, true) of - true -> - CowboyOptions1; - _ -> - CowboyOptions1#{env => #{dispatch => Dispatch}} - end, - +start_cowboy(BootstrapApp, Configuration) -> + %% Determine if we have an already started cowboy on the host/port configuration Host = maps:get(ip, Configuration, { 0, 0, 0, 0}), - - case maps:get(use_ssl, Configuration, false) of - false -> - Port = maps:get(port, Configuration, ?NOVA_STD_PORT), - case cowboy:start_clear( - ?NOVA_LISTENER, - [{port, Port}, - {ip, Host}], - CowboyOptions2) of - {ok, _Pid} -> - {ok, BootstrapApp, Host, Port}; - Error -> - Error - end; + Port = maps:get(port, Configuration, ?NOVA_STD_PORT), + + Listeners = nova:get_env(?COWBOY_LISTENERS, []), + AlreadyStarted = lists:any(fun({X, Y}) -> X == Host andalso Y == Port end, Listeners), + + %% If yes we only need to add things to the dispatch + case AlreadyStarted of + true -> + %% A cowboy listener is already running on this host/port configuration - just add to the + %% dispatch. + logger:info(#{msg => <<"There's already a Cowboy listener running with the host/port config. Adding routes to dispatch.">>, host => Host, port => Port}), + ok; _ -> - case maps:get(ca_cert, Configuration, undefined) of - undefined -> - Port = maps:get(ssl_port, Configuration, ?NOVA_STD_SSL_PORT), - SSLOptions = maps:get(ssl_options, Configuration, #{}), - TransportOpts = maps:put(port, Port, SSLOptions), - TransportOpts1 = maps:put(ip, Host, TransportOpts), - - case cowboy:start_tls( - ?NOVA_LISTENER, maps:to_list(TransportOpts1), CowboyOptions2) of + %% Cowboy configuration + Middlewares = [ + nova_router, %% Lookup routes + nova_plugin_handler, %% Handle pre-request plugins + nova_security_handler, %% Handle security + nova_handler, %% Controller + nova_plugin_handler %% Handle post-request plugins + ], + StreamH = [ + nova_stream_h, + cowboy_compress_h, + cowboy_stream_h + ], + + %% Good debug message in case someone wants to double check which config they are running with + logger:debug(#{msg => <<"Configure cowboy">>, stream_handlers => StreamH, middlewares => Middlewares}), + + StreamHandlers = maps:get(stream_handlers, Configuration, StreamH), + MiddlewareHandlers = maps:get(middleware_handlers, Configuration, Middlewares), + Options = maps:get(options, Configuration, #{compress => true}), + + %% Build the options map + CowboyOptions1 = Options#{middlewares => MiddlewareHandlers, + stream_handlers => StreamHandlers}, + + %% Compile the routes + Dispatch = + case BootstrapApp of + undefined -> + ?LOG_ERROR(#{msg => <<"You need to define bootstrap_application option in configuration">>}), + throw({error, no_nova_app_defined}); + App -> + ExtraApps = application:get_env(App, nova_apps, []), + nova_router:compile([nova|[App|ExtraApps]]) + end, + + CowboyOptions2 = + case application:get_env(nova, use_persistent_term, true) of + true -> + CowboyOptions1; + _ -> + CowboyOptions1#{env => #{dispatch => Dispatch}} + end, + + case maps:get(use_ssl, Configuration, false) of + false -> + case cowboy:start_clear( + ?NOVA_LISTENER(BootstrapApp, Port), + [{port, Port}, + {ip, Host}], + CowboyOptions2) of {ok, _Pid} -> - ?LOG_NOTICE(#{msg => <<"Nova starting SSL">>, port => Port}), + nova:set_env(?COWBOY_LISTENERS, [{Host, Port}|Listeners]), + ets:insert(?NOVA_SUP_TABLE, #nova_server{ + app = BootstrapApp, + host = Host, + port = Port, + listener = ?NOVA_LISTENER(BootstrapApp, Port) + }), {ok, BootstrapApp, Host, Port}; Error -> - ?LOG_ERROR(#{msg => <<"Could not start cowboy with SSL">>, reason => Error}), Error end; - CACert -> - Cert = maps:get(cert, Configuration), - Port = maps:get(ssl_port, Configuration, ?NOVA_STD_SSL_PORT), - ?LOG_DEPRECATED(<<"0.10.3">>, <<"Use of use_ssl is deprecated, use ssl instead">>), + _ -> + SSLPort = maps:get(ssl_port, Configuration, ?NOVA_STD_SSL_PORT), + SSLOptions = maps:get(ssl_options, Configuration, #{}), + TransportOpts = maps:put(port, SSLPort, SSLOptions), + TransportOpts1 = maps:put(ip, Host, TransportOpts), + case cowboy:start_tls( - ?NOVA_LISTENER, [ - {port, Port}, - {ip, Host}, - {certfile, Cert}, - {cacertfile, CACert} - ], - CowboyOptions2) of + ?NOVA_LISTENER(BootstrapApp, SSLPort), + maps:to_list(TransportOpts1), CowboyOptions2) of {ok, _Pid} -> - ?LOG_NOTICE(#{msg => <<"Nova starting SSL">>, port => Port}), - {ok, BootstrapApp, Host, Port}; + ?LOG_NOTICE(#{msg => <<"Nova starting SSL">>, port => SSLPort}), + ets:insert(?NOVA_SUP_TABLE, #nova_server{ + app = BootstrapApp, + host = Host, + port = SSLPort, + listener = ?NOVA_LISTENER(BootstrapApp, SSLPort) + }), + nova:set_env(?COWBOY_LISTENERS, [{Host, SSLPort}|Listeners]), + {ok, BootstrapApp, Host, SSLPort}; Error -> + ?LOG_ERROR(#{msg => <<"Could not start cowboy with SSL">>, reason => Error}), Error end end