From 46f687806451e35eba4373f7e136a0663a5ed537 Mon Sep 17 00:00:00 2001 From: Tim Bradgate Date: Mon, 14 Apr 2025 00:47:06 +0100 Subject: [PATCH 01/11] Add stage direction styles configuration options --- client/package-lock.json | 64 ++ client/package.json | 1 + client/src/main.js | 8 + client/src/store/modules/script.js | 69 +++ client/src/views/show/config/ConfigScript.vue | 6 +- .../config/script/StageDirectionStyles.vue | 566 ++++++++++++++++++ ...a44e01459595_add_stage_direction_styles.py | 57 ++ server/controllers/api/show/script.py | 196 +++++- server/models/script.py | 21 + server/schemas/schemas.py | 10 +- server/utils/database.py | 16 +- 11 files changed, 1008 insertions(+), 6 deletions(-) create mode 100644 client/src/vue_components/show/config/script/StageDirectionStyles.vue create mode 100644 server/alembic_config/versions/a44e01459595_add_stage_direction_styles.py diff --git a/client/package-lock.json b/client/package-lock.json index 364aa76c..3b2f019d 100644 --- a/client/package-lock.json +++ b/client/package-lock.json @@ -30,6 +30,7 @@ "@babel/core": "7.26.10", "@babel/eslint-parser": "7.27.0", "@babel/preset-env": "7.26.9", + "@types/vuelidate": "^0.7.22", "@vitejs/plugin-vue2": "2.3.3", "eslint": "8.57.0", "eslint-config-airbnb-base": "15.0.0", @@ -2666,6 +2667,52 @@ "dev": true, "license": "MIT" }, + "node_modules/@types/vuelidate": { + "version": "0.7.22", + "resolved": "https://registry.npmjs.org/@types/vuelidate/-/vuelidate-0.7.22.tgz", + "integrity": "sha512-bD3pP9FgL3pxMVQ9NJ3d8BbV8Ij6xsrDKdCO4l1Wq/AksXxRRmQ9lmYjRJwn/hLMcgWO/k0QdULfZWpRz13adw==", + "dev": true, + "license": "MIT", + "dependencies": { + "vue": "^2.7.15" + } + }, + "node_modules/@types/vuelidate/node_modules/@vue/compiler-sfc": { + "version": "2.7.16", + "resolved": "https://registry.npmjs.org/@vue/compiler-sfc/-/compiler-sfc-2.7.16.tgz", + "integrity": "sha512-KWhJ9k5nXuNtygPU7+t1rX6baZeqOYLEforUPjgNDBnLicfHCoi48H87Q8XyLZOrNNsmhuwKqtpDQWjEFe6Ekg==", + "dev": true, + "dependencies": { + "@babel/parser": "^7.23.5", + "postcss": "^8.4.14", + "source-map": "^0.6.1" + }, + "optionalDependencies": { + "prettier": "^1.18.2 || ^2.0.0" + } + }, + "node_modules/@types/vuelidate/node_modules/source-map": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", + "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==", + "dev": true, + "license": "BSD-3-Clause", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/@types/vuelidate/node_modules/vue": { + "version": "2.7.16", + "resolved": "https://registry.npmjs.org/vue/-/vue-2.7.16.tgz", + "integrity": "sha512-4gCtFXaAA3zYZdTp5s4Hl2sozuySsgz4jy1EnpBHNfpMa9dK1ZCG7viqBPCwXtmgc8nHqUsAu3G4gtmXkkY3Sw==", + "deprecated": "Vue 2 has reached EOL and is no longer actively maintained. See https://v2.vuejs.org/eol/ for more details.", + "dev": true, + "license": "MIT", + "dependencies": { + "@vue/compiler-sfc": "2.7.16", + "csstype": "^3.1.0" + } + }, "node_modules/@ungap/structured-clone": { "version": "1.2.1", "resolved": "https://registry.npmjs.org/@ungap/structured-clone/-/structured-clone-1.2.1.tgz", @@ -7361,6 +7408,23 @@ "node": ">= 0.8.0" } }, + "node_modules/prettier": { + "version": "2.8.8", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-2.8.8.tgz", + "integrity": "sha512-tdN8qQGvNjw4CHbY+XXk0JgCXn9QiF21a55rBe5LJAU+kDyC4WQn4+awm2Xfk2lQMk5fKup9XgzTZtGkjBdP9Q==", + "dev": true, + "license": "MIT", + "optional": true, + "bin": { + "prettier": "bin-prettier.js" + }, + "engines": { + "node": ">=10.13.0" + }, + "funding": { + "url": "https://github.com/prettier/prettier?sponsor=1" + } + }, "node_modules/process-nextick-args": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/process-nextick-args/-/process-nextick-args-2.0.1.tgz", diff --git a/client/package.json b/client/package.json index e1ebe7ee..145d35bf 100644 --- a/client/package.json +++ b/client/package.json @@ -34,6 +34,7 @@ "@babel/core": "7.26.10", "@babel/eslint-parser": "7.27.0", "@babel/preset-env": "7.26.9", + "@types/vuelidate": "^0.7.22", "@vitejs/plugin-vue2": "2.3.3", "eslint": "8.57.0", "eslint-config-airbnb-base": "15.0.0", diff --git a/client/src/main.js b/client/src/main.js index 68ee773b..c4008699 100644 --- a/client/src/main.js +++ b/client/src/main.js @@ -64,6 +64,14 @@ Vue.filter('capitalize', (value) => { if (!value) return ''; return value.toString().split(' ').map((word) => word.charAt(0).toUpperCase() + word.slice(1)).join(' '); }); +Vue.filter('uppercase', (value) => { + if (!value) return ''; + return value.toString().toUpperCase(); +}); +Vue.filter('lowercase', (value) => { + if (!value) return ''; + return value.toString().toLowerCase(); +}); new Vue({ router, diff --git a/client/src/store/modules/script.js b/client/src/store/modules/script.js index 0437be64..66772997 100644 --- a/client/src/store/modules/script.js +++ b/client/src/store/modules/script.js @@ -10,6 +10,7 @@ export default { script: {}, cues: {}, cuts: [], + stageDirectionStyles: [], }, mutations: { SET_REVISIONS(state, revisions) { @@ -27,6 +28,9 @@ export default { SET_CUTS(state, cuts) { state.cuts = cuts; }, + SET_STAGE_DIRECTION_STYLES(state, styles) { + state.stageDirectionStyles = styles; + }, }, actions: { async GET_SCRIPT_REVISIONS(context) { @@ -216,6 +220,68 @@ export default { Vue.$toast.error('Unable to save script cuts'); } }, + async GET_STAGE_DIRECTION_STYLES(context) { + const response = await fetch(`${makeURL('/api/v1/show/script/stage_direction_styles')}`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + }, + }); + if (response.ok) { + const respJson = await response.json(); + context.commit('SET_STAGE_DIRECTION_STYLES', respJson.styles); + } else { + log.error('Unable to load stage direction styles'); + } + }, + async ADD_STAGE_DIRECTION_STYLE(context, style) { + const response = await fetch(`${makeURL('/api/v1/show/script/stage_direction_styles')}`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify(style), + }); + if (response.ok) { + context.dispatch('GET_STAGE_DIRECTION_STYLES'); + Vue.$toast.success('Added new stage direction style!'); + } else { + log.error('Unable to add new stage direction style'); + Vue.$toast.error('Unable to add new stage direction style'); + } + }, + async DELETE_STAGE_DIRECTION_STYLE(context, styleId) { + const response = await fetch(`${makeURL('/api/v1/show/script/stage_direction_styles')}`, { + method: 'DELETE', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ id: styleId }), + }); + if (response.ok) { + context.dispatch('GET_STAGE_DIRECTION_STYLES'); + Vue.$toast.success('Deleted stage direction style!'); + } else { + log.error('Unable to delete stage direction style'); + Vue.$toast.error('Unable to delete stage direction style'); + } + }, + async UPDATE_STAGE_DIRECTION_STYLE(context, style) { + const response = await fetch(`${makeURL('/api/v1/show/script/stage_direction_styles')}`, { + method: 'PATCH', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify(style), + }); + if (response.ok) { + context.dispatch('GET_STAGE_DIRECTION_STYLES'); + Vue.$toast.success('Updated stage direction style!'); + } else { + log.error('Unable to edit stage direction style'); + Vue.$toast.error('Unable to edit stage direction style'); + } + }, }, getters: { SCRIPT_REVISIONS(state) { @@ -237,5 +303,8 @@ export default { SCRIPT_CUTS(state) { return state.cuts; }, + STAGE_DIRECTION_STYLES(state) { + return state.stageDirectionStyles; + }, }, }; diff --git a/client/src/views/show/config/ConfigScript.vue b/client/src/views/show/config/ConfigScript.vue index caafb43a..7923f2a5 100644 --- a/client/src/views/show/config/ConfigScript.vue +++ b/client/src/views/show/config/ConfigScript.vue @@ -82,6 +82,9 @@ + + + @@ -134,10 +137,11 @@ import { mapActions, mapGetters } from 'vuex'; import { required } from 'vuelidate/lib/validators'; import ScriptConfig from '@/vue_components/show/config/script/ScriptEditor.vue'; +import StageDirectionStyles from '@/vue_components/show/config/script/StageDirectionStyles.vue'; export default { name: 'ConfigScript', - components: { ScriptConfig }, + components: { ScriptConfig, StageDirectionConfigs: StageDirectionStyles }, data() { return { revisionColumns: [ diff --git a/client/src/vue_components/show/config/script/StageDirectionStyles.vue b/client/src/vue_components/show/config/script/StageDirectionStyles.vue new file mode 100644 index 00000000..ac9b425f --- /dev/null +++ b/client/src/vue_components/show/config/script/StageDirectionStyles.vue @@ -0,0 +1,566 @@ + + + diff --git a/server/alembic_config/versions/a44e01459595_add_stage_direction_styles.py b/server/alembic_config/versions/a44e01459595_add_stage_direction_styles.py new file mode 100644 index 00000000..56df9ac7 --- /dev/null +++ b/server/alembic_config/versions/a44e01459595_add_stage_direction_styles.py @@ -0,0 +1,57 @@ +"""Add stage direction styles + +Revision ID: a44e01459595 +Revises: d4f66f58158b +Create Date: 2025-04-13 23:19:47.362110 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'a44e01459595' +down_revision: Union[str, None] = 'd4f66f58158b' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('stage_direction_styles', + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('script_id', sa.Integer(), nullable=True), + sa.Column('description', sa.String(), nullable=True), + sa.Column('bold', sa.Boolean(), nullable=True), + sa.Column('italic', sa.Boolean(), nullable=True), + sa.Column('underline', sa.Boolean(), nullable=True), + sa.Column('text_format', sa.String(), nullable=True), + sa.Column('text_colour', sa.String(), nullable=True), + sa.Column('enable_background_colour', sa.Boolean(), nullable=True), + sa.Column('background_colour', sa.String(), nullable=True), + sa.ForeignKeyConstraint(['script_id'], ['script.id'], name=op.f('fk_stage_direction_styles_script_id_script')), + sa.PrimaryKeyConstraint('id', name=op.f('pk_stage_direction_styles')) + ) + with op.batch_alter_table('stage_direction_styles', schema=None) as batch_op: + batch_op.create_index(batch_op.f('ix_stage_direction_styles_script_id'), ['script_id'], unique=False) + + with op.batch_alter_table('script_lines', schema=None) as batch_op: + batch_op.add_column(sa.Column('stage_direction_style_id', sa.Integer(), nullable=True)) + batch_op.create_foreign_key(batch_op.f('fk_script_lines_stage_direction_style_id_stage_direction_styles'), 'stage_direction_styles', ['stage_direction_style_id'], ['id'], ondelete='SET NULL') + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('script_lines', schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f('fk_script_lines_stage_direction_style_id_stage_direction_styles'), type_='foreignkey') + batch_op.drop_column('stage_direction_style_id') + + with op.batch_alter_table('stage_direction_styles', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('ix_stage_direction_styles_script_id')) + + op.drop_table('stage_direction_styles') + # ### end Alembic commands ### diff --git a/server/controllers/api/show/script.py b/server/controllers/api/show/script.py index 1aeed387..5bf66d98 100644 --- a/server/controllers/api/show/script.py +++ b/server/controllers/api/show/script.py @@ -6,11 +6,11 @@ from models.cue import CueAssociation from models.script import (Script, ScriptRevision, ScriptLine, ScriptLineRevisionAssociation, - ScriptLinePart, ScriptCuts) + ScriptLinePart, ScriptCuts, StageDirectionStyle) from models.show import Show from models.session import Session from rbac.role import Role -from schemas.schemas import ScriptRevisionsSchema, ScriptLineSchema +from schemas.schemas import ScriptRevisionsSchema, ScriptLineSchema, StageDirectionStyleSchema from utils.web.base_controller import BaseAPIController from utils.web.web_decorators import requires_show, no_live_session from utils.web.route import ApiRoute, ApiVersion @@ -897,3 +897,195 @@ def get(self): self.set_status(404) self.finish({'message': '404 show not found'}) return + +@ApiRoute('/show/script/stage_direction_styles', ApiVersion.V1) +class StageDirectionStylesController(BaseAPIController): + @requires_show + def get(self): + current_show = self.get_current_show() + show_id = current_show['id'] + + stage_direction_style_schema = StageDirectionStyleSchema() + with self.make_session() as session: + show = session.query(Show).get(show_id) + if show: + script: Script = session.query(Script).filter(Script.show_id == show.id).first() + stage_direction_styles = [stage_direction_style_schema.dump(style) for style in script.stage_direction_styles] + + self.set_status(200) + self.finish({'styles': stage_direction_styles}) + else: + self.set_status(404) + self.finish({'message': '404 show not found'}) + return + + @requires_show + @no_live_session + async def post(self): + current_show = self.get_current_show() + show_id = current_show['id'] + + with self.make_session() as session: + show = session.query(Show).get(show_id) + if show: + script: Script = session.query(Script).filter(Script.show_id == show.id).first() + self.requires_role(script, Role.WRITE) + data = escape.json_decode(self.request.body) + + description: str = data.get('description', None) + if not description: + self.set_status(400) + await self.finish({'message': 'Description missing'}) + return + + bold: bool = data.get('bold', False) + italic: bool = data.get('italic', False) + underline: bool = data.get('underline', False) + + text_format: str = data.get('textFormat', None) + if not text_format or text_format not in ['default', 'upper', 'lower']: + self.set_status(400) + await self.finish({'message': 'Text format missing or invalid'}) + return + + text_colour: str = data.get('textColour', None) + if not text_colour: + self.set_status(400) + await self.finish({'message': 'Text colour missing'}) + return + + enable_background_colour: bool = data.get('enableBackgroundColour', False) + background_colour: str = data.get('backgroundColour', None) + if enable_background_colour and not background_colour: + self.set_status(400) + await self.finish({'message': 'Background colour missing'}) + return + + new_style = StageDirectionStyle( + script_id=script.id, + description=description, + bold=bold, + italic=italic, + underline=underline, + text_format=text_format, + text_colour=text_colour, + enable_background_colour=enable_background_colour, + background_colour=background_colour, + ) + session.add(new_style) + session.commit() + + self.set_status(200) + await self.finish({'id': new_style.id, 'message': 'Successfully added stage direction style'}) + + await self.application.ws_send_to_all('NOOP', 'GET_STAGE_DIRECTION_STYLES', {}) + else: + self.set_status(404) + await self.finish({'message': '404 show not found'}) + + @requires_show + @no_live_session + async def patch(self): + current_show = self.get_current_show() + show_id = current_show['id'] + + with self.make_session() as session: + show: Show = session.query(Show).get(show_id) + if show: + script: Script = session.query(Script).filter(Script.show_id == show.id).first() + self.requires_role(script, Role.WRITE) + data = escape.json_decode(self.request.body) + + style_id = data.get('id', None) + if not style_id: + self.set_status(400) + await self.finish({'message': 'ID missing'}) + return + + style: StageDirectionStyle = session.query(StageDirectionStyle).get(style_id) + if not style: + self.set_status(404) + await self.finish({'message': '404 stage direction style not found'}) + return + + description: str = data.get('description', None) + if not description: + self.set_status(400) + await self.finish({'message': 'Description missing'}) + return + + bold: bool = data.get('bold', False) + italic: bool = data.get('italic', False) + underline: bool = data.get('underline', False) + + text_format: str = data.get('textFormat', None) + if not text_format or text_format not in ['default', 'upper', 'lower']: + self.set_status(400) + await self.finish({'message': 'Text format missing or invalid'}) + return + + text_colour: str = data.get('textColour', None) + if not text_colour: + self.set_status(400) + await self.finish({'message': 'Text colour missing'}) + return + + enable_background_colour: bool = data.get('enableBackgroundColour', False) + background_colour: str = data.get('backgroundColour', None) + if enable_background_colour and not background_colour: + self.set_status(400) + await self.finish({'message': 'Background colour missing'}) + return + + style.description = description + style.bold = bold + style.italic = italic + style.underline = underline + style.text_format = text_format + style.text_colour = text_colour + style.enable_background_colour = enable_background_colour + style.background_colour = background_colour + session.commit() + + self.set_status(200) + await self.finish({'message': 'Successfully edited stage direction style'}) + + await self.application.ws_send_to_all('NOOP', 'GET_STAGE_DIRECTION_STYLES', {}) + else: + self.set_status(404) + await self.finish({'message': '404 show not found'}) + + @requires_show + @no_live_session + async def delete(self): + current_show = self.get_current_show() + show_id = current_show['id'] + + with self.make_session() as session: + show: Show = session.query(Show).get(show_id) + if show: + script: Script = session.query(Script).filter(Script.show_id == show.id).first() + self.requires_role(script, Role.WRITE) + data = escape.json_decode(self.request.body) + + style_id = data.get('id', None) + if not style_id: + self.set_status(400) + await self.finish({'message': 'ID missing'}) + return + + entry: StageDirectionStyle = session.get(StageDirectionStyle, style_id) + if entry: + session.delete(entry) + session.commit() + + self.set_status(200) + await self.finish({'message': 'Successfully deleted stage direction style'}) + + await self.application.ws_send_to_all('NOOP', 'GET_STAGE_DIRECTION_STYLES', {}) + else: + self.set_status(404) + await self.finish({'message': '404 stage direction style not found'}) + else: + self.set_status(404) + await self.finish({'message': '404 show not found'}) diff --git a/server/models/script.py b/server/models/script.py index cf08dd4d..d9bea858 100644 --- a/server/models/script.py +++ b/server/models/script.py @@ -42,9 +42,11 @@ class ScriptLine(db.Model): scene_id = Column(Integer, ForeignKey('scene.id')) page = Column(Integer, index=True) stage_direction = Column(Boolean) + stage_direction_style_id = Column(Integer, ForeignKey('stage_direction_styles.id', ondelete='SET NULL')) act = relationship('Act', uselist=False, back_populates='lines') scene = relationship('Scene', uselist=False, back_populates='lines') + stage_direction_style = relationship('StageDirectionStyle', uselist=False) class ScriptLineRevisionAssociation(db.Model, DeleteMixin): @@ -112,3 +114,22 @@ class ScriptCuts(db.Model): revision = relationship('ScriptRevision', uselist=False, foreign_keys=[revision_id], backref=backref('line_part_cuts', uselist=True, cascade='all, delete-orphan')) + + +class StageDirectionStyle(db.Model): + __tablename__ = 'stage_direction_styles' + + id = Column(Integer, primary_key=True, autoincrement=True) + script_id = Column(Integer, ForeignKey('script.id'), index=True) + + description = Column(String) + bold = Column(Boolean) + italic = Column(Boolean) + underline = Column(Boolean) + text_format = Column(String) + text_colour = Column(String) + enable_background_colour = Column(Boolean) + background_colour = Column(String) + + script=relationship('Script', uselist=False, foreign_keys=[script_id], + backref=backref('stage_direction_styles', uselist=True, cascade='all, delete-orphan')) diff --git a/server/schemas/schemas.py b/server/schemas/schemas.py index 97a22d28..8625832b 100644 --- a/server/schemas/schemas.py +++ b/server/schemas/schemas.py @@ -6,7 +6,7 @@ from models.cue import CueType, Cue from models.mics import Microphone, MicrophoneAllocation from models.models import db -from models.script import ScriptRevision, ScriptLine, ScriptLinePart, Script, ScriptCuts +from models.script import ScriptRevision, ScriptLine, ScriptLinePart, Script, ScriptCuts, StageDirectionStyle from models.show import Show, Cast, Character, CharacterGroup, Act, Scene from models.session import Session, ShowSession from models.user import User @@ -169,6 +169,14 @@ class Meta: include_fk = True +@schema +class StageDirectionStyleSchema(SQLAlchemyAutoSchema): + class Meta: + model = StageDirectionStyle + load_instance = True + include_fk = True + + @schema class ShowSessionSchema(SQLAlchemyAutoSchema): class Meta: diff --git a/server/utils/database.py b/server/utils/database.py index cbe19364..e985a70a 100644 --- a/server/utils/database.py +++ b/server/utils/database.py @@ -1,7 +1,8 @@ import functools +from sqlalchemy import MetaData -from sqlalchemy.orm import sessionmaker -from tornado_sqlalchemy import SQLAlchemy, SessionEx +from sqlalchemy.orm import sessionmaker, declarative_base +from tornado_sqlalchemy import SQLAlchemy, SessionEx, BindMeta class DeleteMixin: @@ -39,3 +40,14 @@ def get_mapper_for_table(self, tablename): if mapper.mapped_table.fullname == tablename: return mapper.entity return None + + def make_declarative_base(self): + convention = { + "ix": 'ix_%(column_0_label)s', + "uq": "uq_%(table_name)s_%(column_0_name)s", + "ck": "ck_%(table_name)s_%(constraint_name)s", + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + "pk": "pk_%(table_name)s" + } + metadata = MetaData(naming_convention=convention) + return declarative_base(metaclass=BindMeta, metadata=metadata) From 83fa210775fb74f4a840ff923adcb0397f036898 Mon Sep 17 00:00:00 2001 From: Tim Bradgate Date: Mon, 14 Apr 2025 00:52:04 +0100 Subject: [PATCH 02/11] Format all the python files --- server/alembic_config/env.py | 6 +- ...a44e01459595_add_stage_direction_styles.py | 80 +- .../d4f66f58158b_initial_alembic_revision.py | 8 +- server/controllers/api/auth.py | 124 ++- server/controllers/api/rbac.py | 156 +-- server/controllers/api/settings.py | 13 +- server/controllers/api/show/acts.py | 135 +-- server/controllers/api/show/cast.py | 76 +- server/controllers/api/show/characters.py | 217 ++-- server/controllers/api/show/cues.py | 231 +++-- server/controllers/api/show/microphones.py | 166 ++-- server/controllers/api/show/scenes.py | 123 ++- server/controllers/api/show/script.py | 934 +++++++++++------- server/controllers/api/show/sessions.py | 68 +- server/controllers/api/show/shows.py | 101 +- server/controllers/api/websocket.py | 4 +- server/controllers/controllers.py | 55 +- server/controllers/ws_controller.py | 267 ++--- server/digi_server/app_server.py | 166 ++-- server/digi_server/logger.py | 42 +- server/digi_server/settings.py | 261 +++-- server/main.py | 31 +- server/models/cue.py | 71 +- server/models/mics.py | 44 +- server/models/models.py | 4 +- server/models/script.py | 170 ++-- server/models/session.py | 33 +- server/models/show.py | 139 +-- server/models/user.py | 6 +- server/rbac/rbac.py | 14 +- server/rbac/rbac_db.py | 140 ++- server/schemas/schemas.py | 23 +- server/test/test_auth_api.py | 235 +++-- server/test/test_digi_server.py | 18 +- server/test/test_settings.py | 22 +- server/test/test_utils.py | 18 +- server/utils/database.py | 22 +- server/utils/file_watcher.py | 17 +- server/utils/pkg_utils.py | 9 +- server/utils/singleton.py | 2 +- server/utils/tree.py | 6 +- server/utils/web/base_controller.py | 57 +- server/utils/web/route.py | 13 +- server/utils/web/web_decorators.py | 19 +- 44 files changed, 2594 insertions(+), 1752 deletions(-) diff --git a/server/alembic_config/env.py b/server/alembic_config/env.py index d14d0503..4be3a1c7 100644 --- a/server/alembic_config/env.py +++ b/server/alembic_config/env.py @@ -2,10 +2,8 @@ import os from logging.config import fileConfig -from sqlalchemy import engine_from_config -from sqlalchemy import pool - from alembic import context +from sqlalchemy import engine_from_config, pool from models import models @@ -93,7 +91,7 @@ def run_migrations_online() -> None: target_metadata=target_metadata, include_schemas=False, include_name=include_name, - render_as_batch=True + render_as_batch=True, ) with context.begin_transaction(): diff --git a/server/alembic_config/versions/a44e01459595_add_stage_direction_styles.py b/server/alembic_config/versions/a44e01459595_add_stage_direction_styles.py index 56df9ac7..d64708c8 100644 --- a/server/alembic_config/versions/a44e01459595_add_stage_direction_styles.py +++ b/server/alembic_config/versions/a44e01459595_add_stage_direction_styles.py @@ -5,53 +5,77 @@ Create Date: 2025-04-13 23:19:47.362110 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. -revision: str = 'a44e01459595' -down_revision: Union[str, None] = 'd4f66f58158b' +revision: str = "a44e01459595" +down_revision: Union[str, None] = "d4f66f58158b" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.create_table('stage_direction_styles', - sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), - sa.Column('script_id', sa.Integer(), nullable=True), - sa.Column('description', sa.String(), nullable=True), - sa.Column('bold', sa.Boolean(), nullable=True), - sa.Column('italic', sa.Boolean(), nullable=True), - sa.Column('underline', sa.Boolean(), nullable=True), - sa.Column('text_format', sa.String(), nullable=True), - sa.Column('text_colour', sa.String(), nullable=True), - sa.Column('enable_background_colour', sa.Boolean(), nullable=True), - sa.Column('background_colour', sa.String(), nullable=True), - sa.ForeignKeyConstraint(['script_id'], ['script.id'], name=op.f('fk_stage_direction_styles_script_id_script')), - sa.PrimaryKeyConstraint('id', name=op.f('pk_stage_direction_styles')) + op.create_table( + "stage_direction_styles", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("script_id", sa.Integer(), nullable=True), + sa.Column("description", sa.String(), nullable=True), + sa.Column("bold", sa.Boolean(), nullable=True), + sa.Column("italic", sa.Boolean(), nullable=True), + sa.Column("underline", sa.Boolean(), nullable=True), + sa.Column("text_format", sa.String(), nullable=True), + sa.Column("text_colour", sa.String(), nullable=True), + sa.Column("enable_background_colour", sa.Boolean(), nullable=True), + sa.Column("background_colour", sa.String(), nullable=True), + sa.ForeignKeyConstraint( + ["script_id"], + ["script.id"], + name=op.f("fk_stage_direction_styles_script_id_script"), + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_stage_direction_styles")), ) - with op.batch_alter_table('stage_direction_styles', schema=None) as batch_op: - batch_op.create_index(batch_op.f('ix_stage_direction_styles_script_id'), ['script_id'], unique=False) + with op.batch_alter_table("stage_direction_styles", schema=None) as batch_op: + batch_op.create_index( + batch_op.f("ix_stage_direction_styles_script_id"), + ["script_id"], + unique=False, + ) - with op.batch_alter_table('script_lines', schema=None) as batch_op: - batch_op.add_column(sa.Column('stage_direction_style_id', sa.Integer(), nullable=True)) - batch_op.create_foreign_key(batch_op.f('fk_script_lines_stage_direction_style_id_stage_direction_styles'), 'stage_direction_styles', ['stage_direction_style_id'], ['id'], ondelete='SET NULL') + with op.batch_alter_table("script_lines", schema=None) as batch_op: + batch_op.add_column( + sa.Column("stage_direction_style_id", sa.Integer(), nullable=True) + ) + batch_op.create_foreign_key( + batch_op.f( + "fk_script_lines_stage_direction_style_id_stage_direction_styles" + ), + "stage_direction_styles", + ["stage_direction_style_id"], + ["id"], + ondelete="SET NULL", + ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('script_lines', schema=None) as batch_op: - batch_op.drop_constraint(batch_op.f('fk_script_lines_stage_direction_style_id_stage_direction_styles'), type_='foreignkey') - batch_op.drop_column('stage_direction_style_id') + with op.batch_alter_table("script_lines", schema=None) as batch_op: + batch_op.drop_constraint( + batch_op.f( + "fk_script_lines_stage_direction_style_id_stage_direction_styles" + ), + type_="foreignkey", + ) + batch_op.drop_column("stage_direction_style_id") - with op.batch_alter_table('stage_direction_styles', schema=None) as batch_op: - batch_op.drop_index(batch_op.f('ix_stage_direction_styles_script_id')) + with op.batch_alter_table("stage_direction_styles", schema=None) as batch_op: + batch_op.drop_index(batch_op.f("ix_stage_direction_styles_script_id")) - op.drop_table('stage_direction_styles') + op.drop_table("stage_direction_styles") # ### end Alembic commands ### diff --git a/server/alembic_config/versions/d4f66f58158b_initial_alembic_revision.py b/server/alembic_config/versions/d4f66f58158b_initial_alembic_revision.py index 19ed2c28..1b2cd1db 100644 --- a/server/alembic_config/versions/d4f66f58158b_initial_alembic_revision.py +++ b/server/alembic_config/versions/d4f66f58158b_initial_alembic_revision.py @@ -1,18 +1,18 @@ """Initial Alembic Revision Revision ID: d4f66f58158b -Revises: +Revises: Create Date: 2024-06-02 15:50:23.550851 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. -revision: str = 'd4f66f58158b' +revision: str = "d4f66f58158b" down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/server/controllers/api/auth.py b/server/controllers/api/auth.py index e7abafee..65da1ced 100644 --- a/server/controllers/api/auth.py +++ b/server/controllers/api/auth.py @@ -13,39 +13,41 @@ from utils.web.web_decorators import require_admin, requires_show -@ApiRoute('auth/create', ApiVersion.V1) +@ApiRoute("auth/create", ApiVersion.V1) class UserCreateController(BaseAPIController): async def post(self): data = escape.json_decode(self.request.body) - username = data.get('username', '') + username = data.get("username", "") if not username: self.set_status(400) - await self.finish({'message': 'Username missing'}) + await self.finish({"message": "Username missing"}) return - password = data.get('password', '') + password = data.get("password", "") if not password: self.set_status(400) - await self.finish({'message': 'Password missing'}) + await self.finish({"message": "Password missing"}) return if len(password) < 6: self.set_status(400) - await self.finish({'message': 'Password must be at least 6 characters long'}) + await self.finish( + {"message": "Password must be at least 6 characters long"} + ) return - show_id = data.get('show_id', None) - is_admin = data.get('is_admin', False) + show_id = data.get("show_id", None) + is_admin = data.get("is_admin", False) if not show_id and not is_admin: self.set_status(400) - await self.finish({'message': 'Non admin user requires a show allocation'}) + await self.finish({"message": "Non admin user requires a show allocation"}) return if is_admin and show_id: self.set_status(400) - await self.finish({'message': 'Admin user cannot have a show allocation'}) + await self.finish({"message": "Admin user cannot have a show allocation"}) return with self.make_session() as session: @@ -53,69 +55,77 @@ async def post(self): show = session.query(Show).get(show_id) if not show: self.set_status(400) - await self.finish({'message': 'Show not found'}) + await self.finish({"message": "Show not found"}) return - conflict_user = session.query(User).filter(User.username == username).first() + conflict_user = ( + session.query(User).filter(User.username == username).first() + ) if conflict_user: self.set_status(400) - await self.finish({'message': 'Username already taken'}) + await self.finish({"message": "Username already taken"}) return hashed_password = await IOLoop.current().run_in_executor( - None, bcrypt.hashpw, escape.utf8(password), bcrypt.gensalt()) + None, bcrypt.hashpw, escape.utf8(password), bcrypt.gensalt() + ) hashed_password = escape.to_unicode(hashed_password) - session.add(User( - username=username, - password=hashed_password, - show_id=show_id, - is_admin=is_admin)) + session.add( + User( + username=username, + password=hashed_password, + show_id=show_id, + is_admin=is_admin, + ) + ) session.commit() if is_admin: - await self.application.digi_settings.set('has_admin_user', True) + await self.application.digi_settings.set("has_admin_user", True) self.set_status(200) - await self.application.ws_send_to_all('NOOP', 'GET_USERS', {}) - await self.finish({'message': 'Successfully created user'}) + await self.application.ws_send_to_all("NOOP", "GET_USERS", {}) + await self.finish({"message": "Successfully created user"}) -@ApiRoute('auth/login', ApiVersion.V1) +@ApiRoute("auth/login", ApiVersion.V1) class LoginHandler(BaseAPIController): async def post(self): data = escape.json_decode(self.request.body) - username = data.get('username', '') + username = data.get("username", "") if not username: self.set_status(400) - await self.finish({'message': 'Username missing'}) + await self.finish({"message": "Username missing"}) return - password = data.get('password', '') + password = data.get("password", "") if not password: self.set_status(400) - await self.finish({'message': 'Password missing'}) + await self.finish({"message": "Password missing"}) return with self.make_session() as session: user = session.query(User).filter(User.username == username).first() if not user: self.set_status(401) - await self.finish({'message': 'Invalid username/password'}) + await self.finish({"message": "Invalid username/password"}) return if not user.is_admin: if not self.get_current_show(): self.set_status(403) - await self.finish({ - 'message': 'Non admin user cannot log in without a loaded show' - }) + await self.finish( + { + "message": "Non admin user cannot log in without a loaded show" + } + ) return - if user.show_id != self.get_current_show()['id']: + if user.show_id != self.get_current_show()["id"]: self.set_status(403) - await self.finish({'message': 'Loaded show does not match user'}) + await self.finish({"message": "Loaded show does not match user"}) return password_equal = await IOLoop.current().run_in_executor( @@ -126,7 +136,7 @@ async def post(self): ) if password_equal: - session_id = data.get('session_id', '') + session_id = data.get("session_id", "") if session_id: ws_session: Session = session.query(Session).get(session_id) if ws_session: @@ -134,22 +144,22 @@ async def post(self): user.last_login = datetime.utcnow() session.commit() - self.set_secure_cookie('digiscript_user_id', str(user.id)) + self.set_secure_cookie("digiscript_user_id", str(user.id)) self.set_status(200) - await self.finish({'message': 'Successful log in'}) + await self.finish({"message": "Successful log in"}) else: self.set_status(401) - await self.finish({'message': 'Invalid username/password'}) + await self.finish({"message": "Invalid username/password"}) -@ApiRoute('auth/logout', ApiVersion.V1) +@ApiRoute("auth/logout", ApiVersion.V1) class LogoutHandler(BaseAPIController): @web.authenticated async def post(self): data = escape.json_decode(self.request.body) if self.current_user: - session_id = data.get('session_id', '') + session_id = data.get("session_id", "") if session_id: with self.make_session() as session: ws_session: Session = session.query(Session).get(session_id) @@ -157,30 +167,33 @@ async def post(self): ws_session.user = None session.commit() - self.clear_cookie('digiscript_user_id') + self.clear_cookie("digiscript_user_id") self.set_status(200) - await self.finish({'message': 'Successfully logged out'}) + await self.finish({"message": "Successfully logged out"}) else: self.set_status(401) - await self.finish({'message': 'No user logged in'}) + await self.finish({"message": "No user logged in"}) -@ApiRoute('auth/validate', ApiVersion.V1) +@ApiRoute("auth/validate", ApiVersion.V1) class AuthValidationHandler(BaseAPIController): @web.authenticated async def get(self): - if self.current_user['is_admin']: + if self.current_user["is_admin"]: self.set_status(200) - await self.finish({'message': 'OK'}) - elif self.current_show and self.current_user['show_id'] == self.current_show['id']: + await self.finish({"message": "OK"}) + elif ( + self.current_show + and self.current_user["show_id"] == self.current_show["id"] + ): self.set_status(200) - await self.finish({'message': 'OK'}) + await self.finish({"message": "OK"}) else: self.set_status(401) - self.write({'message': 'Not Authenticated'}) + self.write({"message": "Not Authenticated"}) -@ApiRoute('auth/users', ApiVersion.V1) +@ApiRoute("auth/users", ApiVersion.V1) class UsersHandler(BaseAPIController): @web.authenticated @require_admin @@ -188,13 +201,18 @@ class UsersHandler(BaseAPIController): def get(self): user_schema = UserSchema() with self.make_session() as session: - users = session.query(User).filter( - (User.show_id == self.get_current_show()['id']) | (User.is_admin)).all() + users = ( + session.query(User) + .filter( + (User.show_id == self.get_current_show()["id"]) | (User.is_admin) + ) + .all() + ) self.set_status(200) - self.finish({'users': [user_schema.dump(u) for u in users]}) + self.finish({"users": [user_schema.dump(u) for u in users]}) -@ApiRoute('/auth', ApiVersion.V1) +@ApiRoute("/auth", ApiVersion.V1) class AuthHandler(BaseAPIController): @web.authenticated def get(self): diff --git a/server/controllers/api/rbac.py b/server/controllers/api/rbac.py index 446d81f5..09325bc7 100644 --- a/server/controllers/api/rbac.py +++ b/server/controllers/api/rbac.py @@ -1,7 +1,7 @@ from collections import defaultdict from sqlalchemy import inspect -from tornado import web, escape +from tornado import escape, web from models.user import User from rbac.role import Role @@ -11,14 +11,16 @@ from utils.web.web_decorators import require_admin -@ApiRoute('rbac/roles', ApiVersion.V1) +@ApiRoute("rbac/roles", ApiVersion.V1) class RBACRolesHandler(BaseAPIController): async def get(self): self.set_status(200) - await self.finish({'roles': [{'key': role.name, 'value': role.value} for role in Role]}) + await self.finish( + {"roles": [{"key": role.name, "value": role.value} for role in Role]} + ) -@ApiRoute('rbac/user/resources', ApiVersion.V1) +@ApiRoute("rbac/user/resources", ApiVersion.V1) class RBACUsersHandler(BaseAPIController): @web.authenticated @require_admin @@ -29,185 +31,201 @@ async def get(self): r_inspect = inspect(resource) res.append(r_inspect.mapped_table.fullname) self.set_status(200) - await self.finish({'resources': res}) + await self.finish({"resources": res}) -@ApiRoute('rbac/user/objects', ApiVersion.V1) +@ApiRoute("rbac/user/objects", ApiVersion.V1) class RBACObjectsHandler(BaseAPIController): @web.authenticated @require_admin async def get(self): - resource = self.get_query_argument('resource', None) - user = self.get_query_argument('user', None) + resource = self.get_query_argument("resource", None) + user = self.get_query_argument("user", None) if not resource: self.set_status(400) - await self.finish({'message': 'resource query parameter not fulfilled'}) + await self.finish({"message": "resource query parameter not fulfilled"}) return mapper = self.application.get_db().get_mapper_for_table(resource) if not mapper: self.set_status(404) - await self.finish({'message': 'object not found'}) + await self.finish({"message": "object not found"}) return objects = self.application.rbac.get_objects_for_resource(mapper) if not user: self.set_status(200) - await self.finish({ - 'objects': [get_registry().get_schema_by_model(o.__class__)().dump(o) for o in objects], - 'display_fields': self.application.rbac.get_display_fields(mapper) - }) + await self.finish( + { + "objects": [ + get_registry().get_schema_by_model(o.__class__)().dump(o) + for o in objects + ], + "display_fields": self.application.rbac.get_display_fields(mapper), + } + ) else: with self.make_session() as session: user = session.query(User).get(int(user)) if not user: self.set_status(404) - await self.finish({'message': 'user not found'}) + await self.finish({"message": "user not found"}) return self.set_status(200) - await self.finish({ - 'objects': [(get_registry().get_schema_by_model(o.__class__)().dump(o), - self.application.rbac.get_roles(user, o).value) for o in objects], - 'display_fields': self.application.rbac.get_display_fields(mapper) - }) - - -@ApiRoute('rbac/user/roles', ApiVersion.V1) + await self.finish( + { + "objects": [ + ( + get_registry() + .get_schema_by_model(o.__class__)() + .dump(o), + self.application.rbac.get_roles(user, o).value, + ) + for o in objects + ], + "display_fields": self.application.rbac.get_display_fields( + mapper + ), + } + ) + + +@ApiRoute("rbac/user/roles", ApiVersion.V1) class RBACUserRolesHandler(BaseAPIController): @web.authenticated async def get(self): with self.make_session() as session: res = defaultdict(list) - user = session.query(User).get(self.current_user['id']) + user = session.query(User).get(self.current_user["id"]) roles = self.application.rbac.get_all_roles(user) for resource in roles: for role in roles[resource]: - res[resource].append([ - get_registry().get_schema_by_model(role[0].__class__)().dump(role[0]), - role[1].value - ]) + res[resource].append( + [ + get_registry() + .get_schema_by_model(role[0].__class__)() + .dump(role[0]), + role[1].value, + ] + ) self.set_status(200) - await self.finish({'roles': res}) + await self.finish({"roles": res}) -@ApiRoute('rbac/user/roles/grant', ApiVersion.V1) +@ApiRoute("rbac/user/roles/grant", ApiVersion.V1) class RBACRolesGrantHandler(BaseAPIController): @web.authenticated @require_admin async def post(self): data = escape.json_decode(self.request.body) - resource = data.get('resource', None) - rbac_object = data.get('object', None) - user = data.get('user', None) - role = data.get('role', None) + resource = data.get("resource", None) + rbac_object = data.get("object", None) + user = data.get("user", None) + role = data.get("role", None) if not resource: self.set_status(400) - await self.finish({'message': 'resource body parameter not fulfilled'}) + await self.finish({"message": "resource body parameter not fulfilled"}) return if not rbac_object: self.set_status(400) - await self.finish({'message': 'object body parameter not fulfilled'}) + await self.finish({"message": "object body parameter not fulfilled"}) return if not user: self.set_status(400) - await self.finish({'message': 'user body parameter not fulfilled'}) + await self.finish({"message": "user body parameter not fulfilled"}) return if not role: self.set_status(400) - await self.finish({'message': 'role body parameter not fulfilled'}) + await self.finish({"message": "role body parameter not fulfilled"}) return mapper = self.application.get_db().get_mapper_for_table(resource) if not mapper: self.set_status(404) - await self.finish({'message': 'resource not found'}) + await self.finish({"message": "resource not found"}) return with self.make_session() as session: user = session.query(User).get(int(user)) if not user: self.set_status(404) - await self.finish({'message': 'user not found'}) + await self.finish({"message": "user not found"}) return resource = inspect(mapper) cols = {} - cols.update({ - col.key: rbac_object.get(col.key) for col in resource.primary_key - }) + cols.update( + {col.key: rbac_object.get(col.key) for col in resource.primary_key} + ) rbac_object = session.query(mapper).get(cols) if not rbac_object: self.set_status(404) - await self.finish({'message': 'object not found'}) + await self.finish({"message": "object not found"}) return self.application.rbac.give_role(user, rbac_object, Role(role)) for socket in self.application.get_all_ws(user.id): - await socket.write_message({ - 'OP': 'NOOP', - 'DATA': {}, - 'ACTION': 'GET_CURRENT_RBAC' - }) + await socket.write_message( + {"OP": "NOOP", "DATA": {}, "ACTION": "GET_CURRENT_RBAC"} + ) -@ApiRoute('rbac/user/roles/revoke', ApiVersion.V1) +@ApiRoute("rbac/user/roles/revoke", ApiVersion.V1) class RBACRolesRevokeHandler(BaseAPIController): @web.authenticated @require_admin async def post(self): data = escape.json_decode(self.request.body) - resource = data.get('resource', None) - rbac_object = data.get('object', None) - user = data.get('user', None) - role = data.get('role', None) + resource = data.get("resource", None) + rbac_object = data.get("object", None) + user = data.get("user", None) + role = data.get("role", None) if not resource: self.set_status(400) - await self.finish({'message': 'resource body parameter not fulfilled'}) + await self.finish({"message": "resource body parameter not fulfilled"}) return if not rbac_object: self.set_status(400) - await self.finish({'message': 'object body parameter not fulfilled'}) + await self.finish({"message": "object body parameter not fulfilled"}) return if not user: self.set_status(400) - await self.finish({'message': 'user body parameter not fulfilled'}) + await self.finish({"message": "user body parameter not fulfilled"}) return if not role: self.set_status(400) - await self.finish({'message': 'role body parameter not fulfilled'}) + await self.finish({"message": "role body parameter not fulfilled"}) return mapper = self.application.get_db().get_mapper_for_table(resource) if not mapper: self.set_status(404) - await self.finish({'message': 'resource not found'}) + await self.finish({"message": "resource not found"}) return with self.make_session() as session: user = session.query(User).get(int(user)) if not user: self.set_status(404) - await self.finish({'message': 'user not found'}) + await self.finish({"message": "user not found"}) return resource = inspect(mapper) cols = {} - cols.update({ - col.key: rbac_object.get(col.key) for col in resource.primary_key - }) + cols.update( + {col.key: rbac_object.get(col.key) for col in resource.primary_key} + ) rbac_object = session.query(mapper).get(cols) if not rbac_object: self.set_status(404) - await self.finish({'message': 'object not found'}) + await self.finish({"message": "object not found"}) return self.application.rbac.revoke_role(user, rbac_object, Role(role)) for socket in self.application.get_all_ws(user.id): - await socket.write_message({ - 'OP': 'NOOP', - 'DATA': {}, - 'ACTION': 'GET_CURRENT_RBAC' - }) + await socket.write_message( + {"OP": "NOOP", "DATA": {}, "ACTION": "GET_CURRENT_RBAC"} + ) diff --git a/server/controllers/api/settings.py b/server/controllers/api/settings.py index 49f255ff..34485abb 100644 --- a/server/controllers/api/settings.py +++ b/server/controllers/api/settings.py @@ -7,7 +7,7 @@ from utils.web.web_decorators import no_live_session, require_admin -@ApiRoute('settings', ApiVersion.V1) +@ApiRoute("settings", ApiVersion.V1) class SettingsController(BaseAPIController): async def get(self): settings: Settings = self.application.digi_settings @@ -21,22 +21,21 @@ async def patch(self): settings: Settings = self.application.digi_settings data = escape.json_decode(self.request.body) - get_logger().debug(f'New settings data patched: {data}') + get_logger().debug(f"New settings data patched: {data}") for k, v in data.items(): await settings.set(k, v) settings_json = await settings.as_json() await self.application.ws_send_to_all( - 'SETTINGS_CHANGED', - 'WS_SETTINGS_CHANGED', - settings_json) + "SETTINGS_CHANGED", "WS_SETTINGS_CHANGED", settings_json + ) self.set_status(200) - self.write({'message': 'Settings updated'}) + self.write({"message": "Settings updated"}) -@ApiRoute('settings/raw', ApiVersion.V1) +@ApiRoute("settings/raw", ApiVersion.V1) class RawSettingsController(BaseAPIController): async def get(self): settings: Settings = self.application.digi_settings diff --git a/server/controllers/api/show/acts.py b/server/controllers/api/show/acts.py index cb226175..21ef879e 100644 --- a/server/controllers/api/show/acts.py +++ b/server/controllers/api/show/acts.py @@ -2,39 +2,41 @@ from tornado import escape -from models.show import Show, Act, Scene +from models.show import Act, Scene, Show from rbac.role import Role from schemas.schemas import ActSchema from utils.web.base_controller import BaseAPIController -from utils.web.web_decorators import requires_show, no_live_session from utils.web.route import ApiRoute, ApiVersion +from utils.web.web_decorators import no_live_session, requires_show -@ApiRoute('show/act', ApiVersion.V1) +@ApiRoute("show/act", ApiVersion.V1) class ActController(BaseAPIController): @requires_show def get(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] act_schema = ActSchema() with self.make_session() as session: show = session.query(Show).get(show_id) if show: - acts: List[Act] = session.query(Act).filter(Act.show_id == show.id).all() + acts: List[Act] = ( + session.query(Act).filter(Act.show_id == show.id).all() + ) acts = [act_schema.dump(c) for c in acts] self.set_status(200) - self.finish({'acts': acts}) + self.finish({"acts": acts}) else: self.set_status(404) - self.finish({'message': '404 show not found'}) + self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def post(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) @@ -42,22 +44,26 @@ async def post(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - name: str = data.get('name', None) + name: str = data.get("name", None) if not name: self.set_status(400) - await self.finish({'message': 'Name missing'}) + await self.finish({"message": "Name missing"}) return - interval_after: bool = data.get('interval_after', None) + interval_after: bool = data.get("interval_after", None) if interval_after is None: self.set_status(400) - await self.finish({'message': 'Interval after missing'}) + await self.finish({"message": "Interval after missing"}) return - previous_act_id: int = data.get('previous_act_id', None) + previous_act_id: int = data.get("previous_act_id", None) - new_act = Act(show_id=show.id, name=name, interval_after=interval_after, - previous_act_id=previous_act_id) + new_act = Act( + show_id=show.id, + name=name, + interval_after=interval_after, + previous_act_id=previous_act_id, + ) session.add(new_act) session.flush() @@ -67,18 +73,20 @@ async def post(self): session.commit() self.set_status(200) - await self.finish({'id': new_act.id, 'message': 'Successfully added act'}) + await self.finish( + {"id": new_act.id, "message": "Successfully added act"} + ) - await self.application.ws_send_to_all('NOOP', 'GET_ACT_LIST', {}) + await self.application.ws_send_to_all("NOOP", "GET_ACT_LIST", {}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def patch(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) @@ -86,51 +94,58 @@ async def patch(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - act_id = data.get('id', None) + act_id = data.get("id", None) if not act_id: self.set_status(400) - await self.finish({'message': 'ID missing'}) + await self.finish({"message": "ID missing"}) return entry: Act = session.get(Act, act_id) if entry: - name = data.get('name', None) + name = data.get("name", None) if not name: self.set_status(400) - await self.finish({'message': 'Name missing'}) + await self.finish({"message": "Name missing"}) return entry.name = name - interval_after: bool = data.get('interval_after', None) + interval_after: bool = data.get("interval_after", None) if interval_after is None: self.set_status(400) - await self.finish({'message': 'Interval after missing'}) + await self.finish({"message": "Interval after missing"}) return entry.interval_after = interval_after - previous_act_id: int = data.get('previous_act_id', None) + previous_act_id: int = data.get("previous_act_id", None) if previous_act_id: if previous_act_id == act_id: self.set_status(400) - await self.finish({'message': 'Previous act cannot be current act'}) + await self.finish( + {"message": "Previous act cannot be current act"} + ) return previous_act: Act = session.query(Act).get(previous_act_id) if not previous_act: self.set_status(400) - await self.finish({'message': 'Previous act not found'}) + await self.finish({"message": "Previous act not found"}) return act_indexes = [act_id] current_act: Act = previous_act - while current_act is not None and current_act.previous_act is not None: + while ( + current_act is not None + and current_act.previous_act is not None + ): if current_act.previous_act.id in act_indexes: self.set_status(400) - await self.finish({ - 'message': 'Previous act cannot form a circular ' - 'dependency between acts' - }) + await self.finish( + { + "message": "Previous act cannot form a circular " + "dependency between acts" + } + ) return current_act = current_act.previous_act @@ -139,22 +154,22 @@ async def patch(self): session.commit() self.set_status(200) - await self.finish({'message': 'Successfully updated act'}) + await self.finish({"message": "Successfully updated act"}) - await self.application.ws_send_to_all('NOOP', 'GET_ACT_LIST', {}) + await self.application.ws_send_to_all("NOOP", "GET_ACT_LIST", {}) else: self.set_status(404) - await self.finish({'message': '404 act not found'}) + await self.finish({"message": "404 act not found"}) return else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def delete(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show: Show = session.query(Show).get(show_id) @@ -162,10 +177,10 @@ async def delete(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - act_id = data.get('id', None) + act_id = data.get("id", None) if not act_id: self.set_status(400) - await self.finish({'message': 'ID missing'}) + await self.finish({"message": "ID missing"}) return entry: Act = session.get(Act, act_id) @@ -184,25 +199,25 @@ async def delete(self): session.commit() self.set_status(200) - await self.finish({'message': 'Successfully deleted act'}) + await self.finish({"message": "Successfully deleted act"}) - await self.application.ws_send_to_all('NOOP', 'GET_ACT_LIST', {}) + await self.application.ws_send_to_all("NOOP", "GET_ACT_LIST", {}) else: self.set_status(404) - await self.finish({'message': '404 act not found'}) + await self.finish({"message": "404 act not found"}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) -@ApiRoute('show/act/first_scene', ApiVersion.V1) +@ApiRoute("show/act/first_scene", ApiVersion.V1) class FirstSceneController(BaseAPIController): @requires_show @no_live_session async def post(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) @@ -210,45 +225,47 @@ async def post(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - act_id: int = data.get('act_id', None) + act_id: int = data.get("act_id", None) if not act_id: self.set_status(400) - await self.finish({'message': 'Act ID missing'}) + await self.finish({"message": "Act ID missing"}) return - if 'scene_id' not in data: + if "scene_id" not in data: self.set_status(400) - await self.finish({'message': 'Scene ID missing'}) + await self.finish({"message": "Scene ID missing"}) return - scene_id: int = data.get('scene_id', None) + scene_id: int = data.get("scene_id", None) if scene_id: scene: Scene = session.query(Scene).get(scene_id) if not scene: self.set_status(404) - await self.finish({'message': '404 scene not found'}) + await self.finish({"message": "404 scene not found"}) return if scene.previous_scene_id: self.set_status(400) - await self.finish({ - 'message': 'First scene cannot already have previous scene' - }) + await self.finish( + { + "message": "First scene cannot already have previous scene" + } + ) return act: Act = session.query(Act).get(act_id) if not act: self.set_status(404) - await self.finish({'message': 'Act not found'}) + await self.finish({"message": "Act not found"}) return act.first_scene_id = scene_id session.commit() self.set_status(200) - await self.finish({'message': 'Successfully set first scene'}) + await self.finish({"message": "Successfully set first scene"}) - await self.application.ws_send_to_all('NOOP', 'GET_ACT_LIST', {}) + await self.application.ws_send_to_all("NOOP", "GET_ACT_LIST", {}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) diff --git a/server/controllers/api/show/cast.py b/server/controllers/api/show/cast.py index 20b98458..a031b79f 100644 --- a/server/controllers/api/show/cast.py +++ b/server/controllers/api/show/cast.py @@ -1,20 +1,20 @@ from tornado import escape -from models.show import Show, Cast +from models.show import Cast, Show from rbac.role import Role from schemas.schemas import CastSchema from utils.web.base_controller import BaseAPIController -from utils.web.web_decorators import requires_show, no_live_session from utils.web.route import ApiRoute, ApiVersion +from utils.web.web_decorators import no_live_session, requires_show -@ApiRoute('show/cast', ApiVersion.V1) +@ApiRoute("show/cast", ApiVersion.V1) class CastController(BaseAPIController): @requires_show def get(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] cast_schema = CastSchema() with self.make_session() as session: @@ -22,16 +22,16 @@ def get(self): if show: cast = [cast_schema.dump(c) for c in show.cast_list] self.set_status(200) - self.finish({'cast': cast}) + self.finish({"cast": cast}) else: self.set_status(404) - self.finish({'message': '404 show not found'}) + self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def post(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) @@ -39,35 +39,41 @@ async def post(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - first_name = data.get('firstName', None) + first_name = data.get("firstName", None) if not first_name: self.set_status(400) - await self.finish({'message': 'First name missing'}) + await self.finish({"message": "First name missing"}) return - last_name = data.get('lastName', None) + last_name = data.get("lastName", None) if not last_name: self.set_status(400) - await self.finish({'message': 'Last name missing'}) + await self.finish({"message": "Last name missing"}) return - new_cast = Cast(show_id=show.id, first_name=first_name, last_name=last_name) + new_cast = Cast( + show_id=show.id, first_name=first_name, last_name=last_name + ) session.add(new_cast) session.commit() self.set_status(200) - await self.finish({'id': new_cast.id, 'message': 'Successfully added cast member'}) + await self.finish( + {"id": new_cast.id, "message": "Successfully added cast member"} + ) - await self.application.ws_send_to_all('GET_CAST_LIST', 'GET_CAST_LIST', {}) + await self.application.ws_send_to_all( + "GET_CAST_LIST", "GET_CAST_LIST", {} + ) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def patch(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) @@ -75,47 +81,49 @@ async def patch(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - cast_id = data.get('id', None) + cast_id = data.get("id", None) if not cast_id: self.set_status(400) - await self.finish({'message': 'ID missing'}) + await self.finish({"message": "ID missing"}) return entry: Cast = session.get(Cast, cast_id) if entry: - first_name = data.get('firstName', None) + first_name = data.get("firstName", None) if not first_name: self.set_status(400) - await self.finish({'message': 'First name missing'}) + await self.finish({"message": "First name missing"}) return entry.first_name = first_name - last_name = data.get('lastName', None) + last_name = data.get("lastName", None) if not last_name: self.set_status(400) - await self.finish({'message': 'Last name missing'}) + await self.finish({"message": "Last name missing"}) return entry.last_name = last_name session.commit() self.set_status(200) - await self.finish({'message': 'Successfully updated cast member'}) + await self.finish({"message": "Successfully updated cast member"}) - await self.application.ws_send_to_all('GET_CAST_LIST', 'GET_CAST_LIST', {}) + await self.application.ws_send_to_all( + "GET_CAST_LIST", "GET_CAST_LIST", {} + ) else: self.set_status(404) - await self.finish({'message': '404 cast member not found'}) + await self.finish({"message": "404 cast member not found"}) return else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def delete(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) @@ -123,10 +131,10 @@ async def delete(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - cast_id = data.get('id', None) + cast_id = data.get("id", None) if not cast_id: self.set_status(400) - await self.finish({'message': 'ID missing'}) + await self.finish({"message": "ID missing"}) return entry = session.get(Cast, cast_id) @@ -135,12 +143,14 @@ async def delete(self): session.commit() self.set_status(200) - await self.finish({'message': 'Successfully deleted cast member'}) + await self.finish({"message": "Successfully deleted cast member"}) - await self.application.ws_send_to_all('GET_CAST_LIST', 'GET_CAST_LIST', {}) + await self.application.ws_send_to_all( + "GET_CAST_LIST", "GET_CAST_LIST", {} + ) else: self.set_status(404) - await self.finish({'message': '404 cast member not found'}) + await self.finish({"message": "404 cast member not found"}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) diff --git a/server/controllers/api/show/characters.py b/server/controllers/api/show/characters.py index ad5391fc..f249be93 100644 --- a/server/controllers/api/show/characters.py +++ b/server/controllers/api/show/characters.py @@ -2,22 +2,22 @@ from tornado import escape -from models.script import Script, ScriptRevision, ScriptLine -from models.show import Show, Cast, Character, CharacterGroup +from models.script import Script, ScriptLine, ScriptRevision +from models.show import Cast, Character, CharacterGroup, Show from rbac.role import Role -from schemas.schemas import CharacterSchema, CharacterGroupSchema +from schemas.schemas import CharacterGroupSchema, CharacterSchema from utils.web.base_controller import BaseAPIController -from utils.web.web_decorators import requires_show, no_live_session from utils.web.route import ApiRoute, ApiVersion +from utils.web.web_decorators import no_live_session, requires_show -@ApiRoute('show/character', ApiVersion.V1) +@ApiRoute("show/character", ApiVersion.V1) class CharacterController(BaseAPIController): @requires_show def get(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] character_schema = CharacterSchema() with self.make_session() as session: @@ -25,16 +25,16 @@ def get(self): if show: characters = [character_schema.dump(c) for c in show.character_list] self.set_status(200) - self.finish({'characters': characters}) + self.finish({"characters": characters}) else: self.set_status(404) - self.finish({'message': '404 show not found'}) + self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def post(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show: Show = session.query(Show).get(show_id) @@ -42,39 +42,48 @@ async def post(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - name = data.get('name', None) + name = data.get("name", None) if not name: self.set_status(400) - await self.finish({'message': 'Name missing'}) + await self.finish({"message": "Name missing"}) return - description = data.get('description', None) - played_by = data.get('played_by', None) + description = data.get("description", None) + played_by = data.get("played_by", None) if played_by: cast_member = session.query(Cast).get(played_by) if not cast_member: self.set_status(404) - await self.finish({'message': '404 cast member found'}) + await self.finish({"message": "404 cast member found"}) return - new_character = Character(show_id=show.id, name=name, description=description, - played_by=played_by) + new_character = Character( + show_id=show.id, + name=name, + description=description, + played_by=played_by, + ) session.add(new_character) session.commit() self.set_status(200) - await self.finish({'id': new_character.id, 'message': 'Successfully added cast member'}) - - await self.application.ws_send_to_all('NOOP', 'GET_CHARACTER_LIST', {}) + await self.finish( + { + "id": new_character.id, + "message": "Successfully added cast member", + } + ) + + await self.application.ws_send_to_all("NOOP", "GET_CHARACTER_LIST", {}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def patch(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show: Show = session.query(Show).get(show_id) @@ -82,52 +91,54 @@ async def patch(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - character_id = data.get('id', None) + character_id = data.get("id", None) if not character_id: self.set_status(400) - await self.finish({'message': 'ID missing'}) + await self.finish({"message": "ID missing"}) return entry: Character = session.get(Character, character_id) if entry: - name = data.get('name', None) + name = data.get("name", None) if not name: self.set_status(400) - await self.finish({'message': 'Name missing'}) + await self.finish({"message": "Name missing"}) return entry.name = name - description = data.get('description', None) + description = data.get("description", None) entry.description = description - played_by = data.get('played_by', None) + played_by = data.get("played_by", None) if played_by: cast_member = session.query(Cast).get(played_by) if not cast_member: self.set_status(404) - await self.finish({'message': '404 cast member found'}) + await self.finish({"message": "404 cast member found"}) return entry.played_by = played_by session.commit() self.set_status(200) - await self.finish({'message': 'Successfully updated character'}) + await self.finish({"message": "Successfully updated character"}) - await self.application.ws_send_to_all('NOOP', 'GET_CHARACTER_LIST', {}) + await self.application.ws_send_to_all( + "NOOP", "GET_CHARACTER_LIST", {} + ) else: self.set_status(404) - await self.finish({'message': '404 character not found'}) + await self.finish({"message": "404 character not found"}) return else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def delete(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show: Show = session.query(Show).get(show_id) @@ -135,10 +146,10 @@ async def delete(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - character_id = data.get('id', None) + character_id = data.get("id", None) if not character_id: self.set_status(400) - await self.finish({'message': 'ID missing'}) + await self.finish({"message": "ID missing"}) return entry: Character = session.get(Character, character_id) @@ -147,34 +158,41 @@ async def delete(self): session.commit() self.set_status(200) - await self.finish({'message': 'Successfully deleted character'}) + await self.finish({"message": "Successfully deleted character"}) - await self.application.ws_send_to_all('NOOP', 'GET_CHARACTER_LIST', {}) + await self.application.ws_send_to_all( + "NOOP", "GET_CHARACTER_LIST", {} + ) else: self.set_status(404) - await self.finish({'message': '404 character not found'}) + await self.finish({"message": "404 character not found"}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) -@ApiRoute('show/character/stats', ApiVersion.V1) +@ApiRoute("show/character/stats", ApiVersion.V1) class CharacterStatsController(BaseAPIController): async def get(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show: Show = session.query(Show).get(show_id) if show: - script: Script = session.query(Script).filter(Script.show_id == show.id).first() + script: Script = ( + session.query(Script).filter(Script.show_id == show.id).first() + ) if script.current_revision: revision: ScriptRevision = session.query(ScriptRevision).get( - script.current_revision) + script.current_revision + ) else: self.set_status(400) - await self.finish({'message': 'Script does not have a current revision'}) + await self.finish( + {"message": "Script does not have a current revision"} + ) return line_counts = defaultdict(lambda: defaultdict(lambda: defaultdict(int))) @@ -186,43 +204,48 @@ async def get(self): if line_part.line_part_cuts is not None: continue if line_part.character_id: - line_counts[line_part.character_id][line.act_id][line.scene_id] += 1 + line_counts[line_part.character_id][line.act_id][ + line.scene_id + ] += 1 elif line_part.character_group_id: for character in line_part.character_group.characters: - line_counts[character.id][line.act_id][line.scene_id] += 1 + line_counts[character.id][line.act_id][ + line.scene_id + ] += 1 self.set_status(200) - await self.finish({'line_counts': line_counts}) + await self.finish({"line_counts": line_counts}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) -@ApiRoute('show/character/group', ApiVersion.V1) +@ApiRoute("show/character/group", ApiVersion.V1) class CharacterGroupController(BaseAPIController): @requires_show def get(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] character_group_schema = CharacterGroupSchema() with self.make_session() as session: show = session.query(Show).get(show_id) if show: - character_groups = [character_group_schema.dump(c) - for c in show.character_group_list] + character_groups = [ + character_group_schema.dump(c) for c in show.character_group_list + ] self.set_status(200) - self.finish({'character_groups': character_groups}) + self.finish({"character_groups": character_groups}) else: self.set_status(404) - self.finish({'message': '404 show not found'}) + self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def post(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) @@ -230,45 +253,52 @@ async def post(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - name = data.get('name', None) + name = data.get("name", None) if not name: self.set_status(400) - await self.finish({'message': 'Name missing'}) + await self.finish({"message": "Name missing"}) return - description = data.get('description', None) - character_list = data.get('characters', []) + description = data.get("description", None) + character_list = data.get("characters", []) character_model_list = [] for character_id in character_list: character = session.query(Character).get(character_id) if not character: self.set_status(404) - await self.finish({'message': f'Character {character_id} not found'}) + await self.finish( + {"message": f"Character {character_id} not found"} + ) return character_model_list.append(character) - session.add(CharacterGroup( - show_id=show_id, - name=name, - description=description, - characters=character_model_list)) + session.add( + CharacterGroup( + show_id=show_id, + name=name, + description=description, + characters=character_model_list, + ) + ) session.commit() self.set_status(200) - await self.finish({'message': 'Successfully added new character group'}) + await self.finish({"message": "Successfully added new character group"}) - await self.application.ws_send_to_all('NOOP', 'GET_CHARACTER_GROUP_LIST', {}) + await self.application.ws_send_to_all( + "NOOP", "GET_CHARACTER_GROUP_LIST", {} + ) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def delete(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show: Show = session.query(Show).get(show_id) @@ -276,10 +306,10 @@ async def delete(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - character_group_id = data.get('id', None) + character_group_id = data.get("id", None) if not character_group_id: self.set_status(400) - await self.finish({'message': 'ID missing'}) + await self.finish({"message": "ID missing"}) return entry: CharacterGroup = session.get(CharacterGroup, character_group_id) @@ -288,23 +318,25 @@ async def delete(self): session.commit() self.set_status(200) - await self.finish({'message': 'Successfully deleted character group'}) + await self.finish( + {"message": "Successfully deleted character group"} + ) - await self.application.ws_send_to_all('NOOP', - 'GET_CHARACTER_GROUP_LIST', - {}) + await self.application.ws_send_to_all( + "NOOP", "GET_CHARACTER_GROUP_LIST", {} + ) else: self.set_status(404) - await self.finish({'message': '404 character not found'}) + await self.finish({"message": "404 character not found"}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def patch(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show: Show = session.query(Show).get(show_id) @@ -312,32 +344,32 @@ async def patch(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - character_group_id = data.get('id', None) + character_group_id = data.get("id", None) if not character_group_id: self.set_status(400) - await self.finish({'message': 'ID missing'}) + await self.finish({"message": "ID missing"}) return entry: CharacterGroup = session.get(CharacterGroup, character_group_id) if entry: - name = data.get('name', None) + name = data.get("name", None) if not name: self.set_status(400) - await self.finish({'message': 'Name missing'}) + await self.finish({"message": "Name missing"}) return entry.name = name - entry.description = data.get('description', None) + entry.description = data.get("description", None) - character_list = data.get('characters', []) + character_list = data.get("characters", []) character_model_list = [] for character_id in character_list: character = session.query(Character).get(character_id) if not character: self.set_status(404) - await self.finish({ - 'message': f'Character {character_id} not found' - }) + await self.finish( + {"message": f"Character {character_id} not found"} + ) return character_model_list.append(character) entry.characters = character_model_list @@ -345,14 +377,17 @@ async def patch(self): session.commit() self.set_status(200) - await self.finish({'message': 'Successfully updated character group'}) + await self.finish( + {"message": "Successfully updated character group"} + ) await self.application.ws_send_to_all( - 'NOOP', 'GET_CHARACTER_GROUP_LIST', {}) + "NOOP", "GET_CHARACTER_GROUP_LIST", {} + ) else: self.set_status(404) - await self.finish({'message': '404 character group not found'}) + await self.finish({"message": "404 character group not found"}) return else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) diff --git a/server/controllers/api/show/cues.py b/server/controllers/api/show/cues.py index 532d2a25..4f887fd3 100644 --- a/server/controllers/api/show/cues.py +++ b/server/controllers/api/show/cues.py @@ -3,23 +3,23 @@ from tornado import escape -from models.cue import CueType, CueAssociation, Cue -from models.script import ScriptRevision, Script +from models.cue import Cue, CueAssociation, CueType +from models.script import Script, ScriptRevision from models.show import Show from rbac.role import Role -from schemas.schemas import CueTypeSchema, CueSchema +from schemas.schemas import CueSchema, CueTypeSchema from utils.web.base_controller import BaseAPIController -from utils.web.web_decorators import requires_show, no_live_session from utils.web.route import ApiRoute, ApiVersion +from utils.web.web_decorators import no_live_session, requires_show -@ApiRoute('show/cues/types', ApiVersion.V1) +@ApiRoute("show/cues/types", ApiVersion.V1) class CueTypesController(BaseAPIController): @requires_show def get(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] cue_type_schema = CueTypeSchema() with self.make_session() as session: @@ -27,16 +27,16 @@ def get(self): if show: cue_types = [cue_type_schema.dump(c) for c in show.cue_type_list] self.set_status(200) - self.finish({'cue_types': cue_types}) + self.finish({"cue_types": cue_types}) else: self.set_status(404) - self.finish({'message': '404 show not found'}) + self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def post(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) @@ -44,41 +44,44 @@ async def post(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - prefix: str = data.get('prefix', None) + prefix: str = data.get("prefix", None) if not prefix: self.set_status(400) - await self.finish({'message': 'Prefix missing'}) + await self.finish({"message": "Prefix missing"}) return - description: str = data.get('description', None) + description: str = data.get("description", None) - colour: str = data.get('colour', None) + colour: str = data.get("colour", None) if not colour: self.set_status(400) - await self.finish({'message': 'Colour missing'}) + await self.finish({"message": "Colour missing"}) return new_cuetype = CueType( show_id=show_id, prefix=prefix, description=description, - colour=colour) + colour=colour, + ) session.add(new_cuetype) session.commit() self.set_status(200) - await self.finish({'id': new_cuetype.id, 'message': 'Successfully added cue type'}) + await self.finish( + {"id": new_cuetype.id, "message": "Successfully added cue type"} + ) - await self.application.ws_send_to_all('NOOP', 'GET_CUE_TYPES', {}) + await self.application.ws_send_to_all("NOOP", "GET_CUE_TYPES", {}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def patch(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show: Show = session.query(Show).get(show_id) @@ -86,30 +89,30 @@ async def patch(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - cue_type_id = data.get('id', None) + cue_type_id = data.get("id", None) if not cue_type_id: self.set_status(400) - await self.finish({'message': 'ID missing'}) + await self.finish({"message": "ID missing"}) return cue_type: CueType = session.query(CueType).get(cue_type_id) if not cue_type: self.set_status(404) - await self.finish({'message': '404 cue type not found'}) + await self.finish({"message": "404 cue type not found"}) return - prefix: str = data.get('prefix', None) + prefix: str = data.get("prefix", None) if not prefix: self.set_status(400) - await self.finish({'message': 'Prefix missing'}) + await self.finish({"message": "Prefix missing"}) return - description: str = data.get('description', None) + description: str = data.get("description", None) - colour: str = data.get('colour', None) + colour: str = data.get("colour", None) if not colour: self.set_status(400) - await self.finish({'message': 'Colour missing'}) + await self.finish({"message": "Colour missing"}) return cue_type.prefix = prefix @@ -118,18 +121,18 @@ async def patch(self): session.commit() self.set_status(200) - await self.finish({'message': 'Successfully added cue type'}) + await self.finish({"message": "Successfully added cue type"}) - await self.application.ws_send_to_all('NOOP', 'GET_CUE_TYPES', {}) + await self.application.ws_send_to_all("NOOP", "GET_CUE_TYPES", {}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def delete(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show: Show = session.query(Show).get(show_id) @@ -137,10 +140,10 @@ async def delete(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - cue_type_id = data.get('id', None) + cue_type_id = data.get("id", None) if not cue_type_id: self.set_status(400) - await self.finish({'message': 'ID missing'}) + await self.finish({"message": "ID missing"}) return entry: CueType = session.get(CueType, cue_type_id) @@ -149,179 +152,203 @@ async def delete(self): session.commit() self.set_status(200) - await self.finish({'message': 'Successfully deleted cue type'}) + await self.finish({"message": "Successfully deleted cue type"}) - await self.application.ws_send_to_all('NOOP', 'GET_CUE_TYPES', {}) + await self.application.ws_send_to_all("NOOP", "GET_CUE_TYPES", {}) else: self.set_status(404) - await self.finish({'message': '404 cue type not found'}) + await self.finish({"message": "404 cue type not found"}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) -@ApiRoute('show/cues', ApiVersion.V1) +@ApiRoute("show/cues", ApiVersion.V1) class CueController(BaseAPIController): @requires_show def get(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] cue_schema = CueSchema() with self.make_session() as session: show = session.query(Show).get(show_id) if show: - script: Script = session.query(Script).filter(Script.show_id == show.id).first() + script: Script = ( + session.query(Script).filter(Script.show_id == show.id).first() + ) if script.current_revision: revision: ScriptRevision = session.query(ScriptRevision).get( - script.current_revision) + script.current_revision + ) else: self.set_status(400) - self.finish({'message': 'Script does not have a current revision'}) + self.finish({"message": "Script does not have a current revision"}) return - revision_cues: List[CueAssociation] = session.query( - CueAssociation).filter(CueAssociation.revision_id == revision.id).all() + revision_cues: List[CueAssociation] = ( + session.query(CueAssociation) + .filter(CueAssociation.revision_id == revision.id) + .all() + ) cues = collections.defaultdict(list) for association in revision_cues: cues[association.line_id].append(cue_schema.dump(association.cue)) self.set_status(200) - self.finish({'cues': cues}) + self.finish({"cues": cues}) else: self.set_status(404) - self.finish({'message': '404 show not found'}) + self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def post(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) if show: - script: Script = session.query(Script).filter(Script.show_id == show.id).first() + script: Script = ( + session.query(Script).filter(Script.show_id == show.id).first() + ) if script.current_revision: revision: ScriptRevision = session.query(ScriptRevision).get( - script.current_revision) + script.current_revision + ) else: self.set_status(400) - await self.finish({'message': 'Script does not have a current revision'}) + await self.finish( + {"message": "Script does not have a current revision"} + ) return data = escape.json_decode(self.request.body) - cue_type_id: int = data.get('cueType', None) + cue_type_id: int = data.get("cueType", None) if not cue_type_id: self.set_status(400) - await self.finish({'message': 'Cue Type missing'}) + await self.finish({"message": "Cue Type missing"}) return cue_type = session.query(CueType).get(cue_type_id) if not cue_type: self.set_status(400) - await self.finish({'message': 'Cue Type is not valid, or cannot be found'}) + await self.finish( + {"message": "Cue Type is not valid, or cannot be found"} + ) return self.requires_role(cue_type, Role.WRITE) - ident: str = data.get('ident', None) + ident: str = data.get("ident", None) if not ident: self.set_status(400) - await self.finish({'message': 'Identifier missing'}) + await self.finish({"message": "Identifier missing"}) return - line_id: int = data.get('lineId', None) + line_id: int = data.get("lineId", None) if not line_id: self.set_status(400) - await self.finish({'message': 'Line ID missing'}) + await self.finish({"message": "Line ID missing"}) return cue = Cue(cue_type_id=cue_type_id, ident=ident) session.add(cue) session.flush() - session.add(CueAssociation(revision_id=revision.id, line_id=line_id, - cue_id=cue.id)) + session.add( + CueAssociation( + revision_id=revision.id, line_id=line_id, cue_id=cue.id + ) + ) session.commit() self.set_status(200) - await self.finish({'message': 'Successfully added cue'}) + await self.finish({"message": "Successfully added cue"}) - await self.application.ws_send_to_all('NOOP', 'LOAD_CUES', {}) + await self.application.ws_send_to_all("NOOP", "LOAD_CUES", {}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def patch(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) if show: - script: Script = session.query(Script).filter(Script.show_id == show.id).first() + script: Script = ( + session.query(Script).filter(Script.show_id == show.id).first() + ) if script.current_revision: revision: ScriptRevision = session.query(ScriptRevision).get( - script.current_revision) + script.current_revision + ) else: self.set_status(400) - await self.finish({'message': 'Script does not have a current revision'}) + await self.finish( + {"message": "Script does not have a current revision"} + ) return data = escape.json_decode(self.request.body) - cue_id: int = data.get('cueId') + cue_id: int = data.get("cueId") if not cue_id: self.set_status(400) - await self.finish({'message': 'Cue ID missing'}) + await self.finish({"message": "Cue ID missing"}) return - cue_type_id: int = data.get('cueType', None) + cue_type_id: int = data.get("cueType", None) if not cue_type_id: self.set_status(400) - await self.finish({'message': 'Cue Type missing'}) + await self.finish({"message": "Cue Type missing"}) return cue_type = session.query(CueType).get(cue_type_id) if not cue_type: self.set_status(400) - await self.finish({'message': 'Cue Type is not valid, or cannot be found'}) + await self.finish( + {"message": "Cue Type is not valid, or cannot be found"} + ) return self.requires_role(cue_type, Role.WRITE) - ident: str = data.get('ident', None) + ident: str = data.get("ident", None) if not ident: self.set_status(400) - await self.finish({'message': 'Identifier missing'}) + await self.finish({"message": "Identifier missing"}) return - line_id: int = data.get('lineId', None) + line_id: int = data.get("lineId", None) if not line_id: self.set_status(400) - await self.finish({'message': 'Line ID missing'}) + await self.finish({"message": "Line ID missing"}) return cue: Cue = session.query(Cue).get(cue_id) if not cue: self.set_status(404) - await self.finish({'message': '404 cue not found'}) + await self.finish({"message": "404 cue not found"}) return current_association: CueAssociation = session.query(CueAssociation).get( - {'revision_id': revision.id, 'line_id': line_id, 'cue_id': cue_id}) + {"revision_id": revision.id, "line_id": line_id, "cue_id": cue_id} + ) if not current_association: self.set_status(400) - await self.finish({'message': 'Unable to load cue line data'}) + await self.finish({"message": "Unable to load cue line data"}) return if len(cue.revision_associations) == 1: @@ -330,8 +357,12 @@ async def patch(self): cue.ident = ident else: self.set_status(400) - await self.finish({'message': 'Cannot edit cue for a revision that is ' - 'not loaded'}) + await self.finish( + { + "message": "Cannot edit cue for a revision that is " + "not loaded" + } + ) return else: new_cue = Cue(ident=ident, cue_type_id=cue_type_id) @@ -342,64 +373,72 @@ async def patch(self): session.commit() self.set_status(200) - await self.finish({'message': 'Successfully edited cue'}) - await self.application.ws_send_to_all('NOOP', 'LOAD_CUES', {}) + await self.finish({"message": "Successfully edited cue"}) + await self.application.ws_send_to_all("NOOP", "LOAD_CUES", {}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def delete(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) if show: - script: Script = session.query(Script).filter(Script.show_id == show.id).first() + script: Script = ( + session.query(Script).filter(Script.show_id == show.id).first() + ) if script.current_revision: revision: ScriptRevision = session.query(ScriptRevision).get( - script.current_revision) + script.current_revision + ) else: self.set_status(400) - await self.finish({'message': 'Script does not have a current revision'}) + await self.finish( + {"message": "Script does not have a current revision"} + ) return data = escape.json_decode(self.request.body) - cue_id: int = data.get('cueId') + cue_id: int = data.get("cueId") if not cue_id: self.set_status(400) - await self.finish({'message': 'Cue ID missing'}) + await self.finish({"message": "Cue ID missing"}) return cue = session.query(Cue).get(cue_id) cue_type = session.query(CueType).get(cue.cue_type_id) self.requires_role(cue_type, Role.WRITE) - line_id: int = data.get('lineId') + line_id: int = data.get("lineId") if not line_id: self.set_status(400) - await self.finish({'message': 'Line ID missing'}) + await self.finish({"message": "Line ID missing"}) return association_object = session.query(CueAssociation).get( - {'revision_id': revision.id, 'line_id': line_id, 'cue_id': cue_id}) + {"revision_id": revision.id, "line_id": line_id, "cue_id": cue_id} + ) if association_object: session.delete(association_object) session.commit() self.set_status(200) - await self.finish({'message': 'Successfully deleted cue'}) - await self.application.ws_send_to_all('NOOP', 'LOAD_CUES', {}) + await self.finish({"message": "Successfully deleted cue"}) + await self.application.ws_send_to_all("NOOP", "LOAD_CUES", {}) else: self.set_status(400) - await self.finish({'message': 'Could not find cue association object'}) + await self.finish( + {"message": "Could not find cue association object"} + ) return else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) diff --git a/server/controllers/api/show/microphones.py b/server/controllers/api/show/microphones.py index 83f20225..cc7997e4 100644 --- a/server/controllers/api/show/microphones.py +++ b/server/controllers/api/show/microphones.py @@ -3,40 +3,43 @@ from tornado import escape from models.mics import Microphone, MicrophoneAllocation -from models.show import Show, Scene, Character +from models.show import Character, Scene, Show from rbac.role import Role -from schemas.schemas import MicrophoneSchema, MicrophoneAllocationSchema +from schemas.schemas import MicrophoneAllocationSchema, MicrophoneSchema from utils.web.base_controller import BaseAPIController -from utils.web.route import ApiVersion, ApiRoute -from utils.web.web_decorators import requires_show, no_live_session +from utils.web.route import ApiRoute, ApiVersion +from utils.web.web_decorators import no_live_session, requires_show -@ApiRoute('show/microphones', ApiVersion.V1) +@ApiRoute("show/microphones", ApiVersion.V1) class MicrophoneController(BaseAPIController): @requires_show def get(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] mic_schema = MicrophoneSchema() with self.make_session() as session: show = session.query(Show).get(show_id) if show: - mics: List[Microphone] = session.query(Microphone).filter( - Microphone.show_id == show.id).all() + mics: List[Microphone] = ( + session.query(Microphone) + .filter(Microphone.show_id == show.id) + .all() + ) mics = [mic_schema.dump(c) for c in mics] self.set_status(200) - self.finish({'microphones': mics}) + self.finish({"microphones": mics}) else: self.set_status(404) - self.finish({'message': '404 show not found'}) + self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def post(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) @@ -44,42 +47,48 @@ async def post(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - name: str = data.get('name', None) + name: str = data.get("name", None) if not name: self.set_status(400) - await self.finish({'message': 'Name missing'}) + await self.finish({"message": "Name missing"}) return - other_named = session.query(Microphone).filter( - Microphone.show_id == show_id, - Microphone.name == name).first() + other_named = ( + session.query(Microphone) + .filter(Microphone.show_id == show_id, Microphone.name == name) + .first() + ) if other_named: self.set_status(400) - await self.finish({'message': 'Name already taken'}) + await self.finish({"message": "Name already taken"}) return - description: str = data.get('description', None) + description: str = data.get("description", None) new_microphone = Microphone( - show_id=show_id, - name=name, - description=description) + show_id=show_id, name=name, description=description + ) session.add(new_microphone) session.commit() self.set_status(200) - await self.finish({'id': new_microphone.id, 'message': 'Successfully added microphone'}) - - await self.application.ws_send_to_all('NOOP', 'GET_MICROPHONE_LIST', {}) + await self.finish( + { + "id": new_microphone.id, + "message": "Successfully added microphone", + } + ) + + await self.application.ws_send_to_all("NOOP", "GET_MICROPHONE_LIST", {}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def patch(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show: Show = session.query(Show).get(show_id) @@ -87,52 +96,57 @@ async def patch(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - microphone_id = data.get('id', None) + microphone_id = data.get("id", None) if not microphone_id: self.set_status(400) - await self.finish({'message': 'ID missing'}) + await self.finish({"message": "ID missing"}) return microphone: Microphone = session.query(Microphone).get(microphone_id) if not microphone: self.set_status(404) - await self.finish({'message': '404 microphone not found'}) + await self.finish({"message": "404 microphone not found"}) return - name: str = data.get('name', None) + name: str = data.get("name", None) if not name: self.set_status(400) - await self.finish({'message': 'Name missing'}) + await self.finish({"message": "Name missing"}) return - other_named = session.query(Microphone).filter( - Microphone.show_id == show_id, - Microphone.name == name, - Microphone.id != microphone_id).first() + other_named = ( + session.query(Microphone) + .filter( + Microphone.show_id == show_id, + Microphone.name == name, + Microphone.id != microphone_id, + ) + .first() + ) if other_named: self.set_status(400) - await self.finish({'message': 'Name already taken'}) + await self.finish({"message": "Name already taken"}) return - description: str = data.get('description', None) + description: str = data.get("description", None) microphone.name = name microphone.description = description session.commit() self.set_status(200) - await self.finish({'message': 'Successfully updated microphone'}) + await self.finish({"message": "Successfully updated microphone"}) - await self.application.ws_send_to_all('NOOP', 'GET_MICROPHONE_LIST', {}) + await self.application.ws_send_to_all("NOOP", "GET_MICROPHONE_LIST", {}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def delete(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show: Show = session.query(Show).get(show_id) @@ -140,10 +154,10 @@ async def delete(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - microphone_id = data.get('id', None) + microphone_id = data.get("id", None) if not microphone_id: self.set_status(400) - await self.finish({'message': 'ID missing'}) + await self.finish({"message": "ID missing"}) return entry: Microphone = session.get(Microphone, microphone_id) @@ -152,48 +166,54 @@ async def delete(self): session.commit() self.set_status(200) - await self.finish({'message': 'Successfully deleted microphone'}) + await self.finish({"message": "Successfully deleted microphone"}) - await self.application.ws_send_to_all('NOOP', 'GET_MICROPHONE_LIST', {}) + await self.application.ws_send_to_all( + "NOOP", "GET_MICROPHONE_LIST", {} + ) else: self.set_status(404) - await self.finish({'message': '404 microphone not found'}) + await self.finish({"message": "404 microphone not found"}) else: self.set_status(404) - await self.finish({'message': '404 microphone not found'}) + await self.finish({"message": "404 microphone not found"}) -@ApiRoute('show/microphones/allocations', ApiVersion.V1) +@ApiRoute("show/microphones/allocations", ApiVersion.V1) class MicrophoneAllocationsController(BaseAPIController): @requires_show def get(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] allocation_schema = MicrophoneAllocationSchema() with self.make_session() as session: show = session.query(Show).get(show_id) if show: - mics: List[Microphone] = session.query(Microphone).filter( - Microphone.show_id == show.id).all() + mics: List[Microphone] = ( + session.query(Microphone) + .filter(Microphone.show_id == show.id) + .all() + ) allocations = {} for mic in mics: - allocations[mic.id] = [allocation_schema.dump(alloc) for - alloc in mic.allocations] + allocations[mic.id] = [ + allocation_schema.dump(alloc) for alloc in mic.allocations + ] self.set_status(200) - self.finish({'allocations': allocations}) + self.finish({"allocations": allocations}) else: self.set_status(404) - self.finish({'message': '404 show not found'}) + self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def patch(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show: Show = session.query(Show).get(show_id) @@ -205,42 +225,50 @@ async def patch(self): mic = session.query(Microphone).get(microphone_id) if not mic: self.set_status(404) - await self.finish({'message': '404 microphone not found'}) + await self.finish({"message": "404 microphone not found"}) return for scene_id in data[microphone_id]: scene = session.query(Scene).get(scene_id) if not scene: self.set_status(404) - await self.finish({'message': '404 scene not found'}) + await self.finish({"message": "404 scene not found"}) return - existing_allocation: MicrophoneAllocation = session.query( - MicrophoneAllocation).filter( - MicrophoneAllocation.scene_id == scene.id, - MicrophoneAllocation.mic_id == mic.id).first() + existing_allocation: MicrophoneAllocation = ( + session.query(MicrophoneAllocation) + .filter( + MicrophoneAllocation.scene_id == scene.id, + MicrophoneAllocation.mic_id == mic.id, + ) + .first() + ) character_id = data[microphone_id][scene_id] if character_id: character = session.query(Character).get(character_id) if not character: self.set_status(404) - await self.finish({'message': '404 character not found'}) + await self.finish( + {"message": "404 character not found"} + ) return if existing_allocation: existing_allocation.character_id = character.id else: - session.add(MicrophoneAllocation( - mic_id=mic.id, - scene_id=scene.id, - character_id=character.id) + session.add( + MicrophoneAllocation( + mic_id=mic.id, + scene_id=scene.id, + character_id=character.id, + ) ) elif existing_allocation: session.delete(existing_allocation) session.flush() - await self.application.ws_send_to_all('NOOP', 'GET_MIC_ALLOCATIONS', {}) + await self.application.ws_send_to_all("NOOP", "GET_MIC_ALLOCATIONS", {}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) diff --git a/server/controllers/api/show/scenes.py b/server/controllers/api/show/scenes.py index 633f6bdb..fb20efb7 100644 --- a/server/controllers/api/show/scenes.py +++ b/server/controllers/api/show/scenes.py @@ -2,39 +2,41 @@ from tornado import escape -from models.show import Show, Scene +from models.show import Scene, Show from rbac.role import Role from schemas.schemas import SceneSchema from utils.web.base_controller import BaseAPIController -from utils.web.web_decorators import requires_show, no_live_session from utils.web.route import ApiRoute, ApiVersion +from utils.web.web_decorators import no_live_session, requires_show -@ApiRoute('show/scene', ApiVersion.V1) +@ApiRoute("show/scene", ApiVersion.V1) class SceneController(BaseAPIController): @requires_show def get(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] scene_schema = SceneSchema() with self.make_session() as session: show = session.query(Show).get(show_id) if show: - scenes: List[Scene] = session.query(Scene).filter(Scene.show_id == show.id).all() + scenes: List[Scene] = ( + session.query(Scene).filter(Scene.show_id == show.id).all() + ) scenes = [scene_schema.dump(c) for c in scenes] self.set_status(200) - self.finish({'scenes': scenes}) + self.finish({"scenes": scenes}) else: self.set_status(404) - self.finish({'message': '404 show not found'}) + self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def post(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) @@ -42,39 +44,42 @@ async def post(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - act_id: int = data.get('act_id', None) + act_id: int = data.get("act_id", None) if not act_id: self.set_status(400) - await self.finish({'message': 'Act ID missing'}) + await self.finish({"message": "Act ID missing"}) return - name: str = data.get('name', None) + name: str = data.get("name", None) if not name: self.set_status(400) - await self.finish({'message': 'Name missing'}) + await self.finish({"message": "Name missing"}) return - previous_scene_id = data.get('previous_scene_id', None) + previous_scene_id = data.get("previous_scene_id", None) if previous_scene_id: previous_scene: Scene = session.query(Scene).get(previous_scene_id) if not previous_scene: self.set_status(400) - await self.finish({'message': 'Previous scene not found'}) + await self.finish({"message": "Previous scene not found"}) return if previous_scene.act_id != act_id: self.set_status(400) - await self.finish({ - 'message': 'Previous scene must be in the same act as new scene' - }) + await self.finish( + { + "message": "Previous scene must be in the same act as new scene" + } + ) return new_scene = Scene( show_id=show_id, act_id=act_id, name=name, - previous_scene_id=previous_scene_id) + previous_scene_id=previous_scene_id, + ) session.add(new_scene) session.flush() @@ -84,19 +89,21 @@ async def post(self): session.commit() self.set_status(200) - await self.finish({'id': new_scene.id, 'message': 'Successfully added scene'}) + await self.finish( + {"id": new_scene.id, "message": "Successfully added scene"} + ) - await self.application.ws_send_to_all('NOOP', 'GET_SCENE_LIST', {}) + await self.application.ws_send_to_all("NOOP", "GET_SCENE_LIST", {}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def delete(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show: Show = session.query(Show).get(show_id) @@ -104,10 +111,10 @@ async def delete(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - scene_id = data.get('id', None) + scene_id = data.get("id", None) if not scene_id: self.set_status(400) - await self.finish({'message': 'ID missing'}) + await self.finish({"message": "ID missing"}) return entry: Scene = session.get(Scene, scene_id) @@ -126,21 +133,21 @@ async def delete(self): session.commit() self.set_status(200) - await self.finish({'message': 'Successfully deleted scene'}) + await self.finish({"message": "Successfully deleted scene"}) - await self.application.ws_send_to_all('NOOP', 'GET_SCENE_LIST', {}) + await self.application.ws_send_to_all("NOOP", "GET_SCENE_LIST", {}) else: self.set_status(404) - await self.finish({'message': '404 scene not found'}) + await self.finish({"message": "404 scene not found"}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def patch(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show: Show = session.query(Show).get(show_id) @@ -148,59 +155,67 @@ async def patch(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - scene_id = data.get('scene_id', None) + scene_id = data.get("scene_id", None) if not scene_id: self.set_status(400) - await self.finish({'message': 'ID missing'}) + await self.finish({"message": "ID missing"}) return entry: Scene = session.get(Scene, scene_id) if entry: - act_id: int = data.get('act_id', None) + act_id: int = data.get("act_id", None) if not act_id: self.set_status(400) - await self.finish({'message': 'Act ID missing'}) + await self.finish({"message": "Act ID missing"}) return - name: str = data.get('name', None) + name: str = data.get("name", None) if not name: self.set_status(400) - await self.finish({'message': 'Name missing'}) + await self.finish({"message": "Name missing"}) return - previous_scene_id = data.get('previous_scene_id', None) + previous_scene_id = data.get("previous_scene_id", None) if previous_scene_id: if previous_scene_id == scene_id: self.set_status(400) - await self.finish({ - 'message': 'Previous scene cannot be current scene' - }) + await self.finish( + {"message": "Previous scene cannot be current scene"} + ) return - previous_scene: Scene = session.query(Scene).get(previous_scene_id) + previous_scene: Scene = session.query(Scene).get( + previous_scene_id + ) if not previous_scene: self.set_status(400) - await self.finish({'message': 'Previous scene not found'}) + await self.finish({"message": "Previous scene not found"}) return if previous_scene.act_id != act_id: self.set_status(400) - await self.finish({ - 'message': 'Previous scene must be in the same act as new scene' - }) + await self.finish( + { + "message": "Previous scene must be in the same act as new scene" + } + ) return scene_indexes = [scene_id] current_scene: Scene = previous_scene - while (current_scene is not None and - current_scene.previous_scene is not None): + while ( + current_scene is not None + and current_scene.previous_scene is not None + ): if current_scene.previous_scene.id in scene_indexes: self.set_status(400) - await self.finish({ - 'message': 'Previous scene cannot form a circular ' - 'dependency between scenes' - }) + await self.finish( + { + "message": "Previous scene cannot form a circular " + "dependency between scenes" + } + ) return current_scene = current_scene.previous_scene @@ -211,12 +226,12 @@ async def patch(self): session.commit() self.set_status(200) - await self.finish({'message': 'Successfully updated scene'}) + await self.finish({"message": "Successfully updated scene"}) - await self.application.ws_send_to_all('NOOP', 'GET_SCENE_LIST', {}) + await self.application.ws_send_to_all("NOOP", "GET_SCENE_LIST", {}) else: self.set_status(404) - await self.finish({'message': '404 scene not found'}) + await self.finish({"message": "404 scene not found"}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) diff --git a/server/controllers/api/show/script.py b/server/controllers/api/show/script.py index 5bf66d98..931cd335 100644 --- a/server/controllers/api/show/script.py +++ b/server/controllers/api/show/script.py @@ -5,179 +5,221 @@ from tornado import escape from models.cue import CueAssociation -from models.script import (Script, ScriptRevision, ScriptLine, ScriptLineRevisionAssociation, - ScriptLinePart, ScriptCuts, StageDirectionStyle) -from models.show import Show +from models.script import ( + Script, + ScriptCuts, + ScriptLine, + ScriptLinePart, + ScriptLineRevisionAssociation, + ScriptRevision, + StageDirectionStyle, +) from models.session import Session +from models.show import Show from rbac.role import Role -from schemas.schemas import ScriptRevisionsSchema, ScriptLineSchema, StageDirectionStyleSchema +from schemas.schemas import ( + ScriptLineSchema, + ScriptRevisionsSchema, + StageDirectionStyleSchema, +) from utils.web.base_controller import BaseAPIController -from utils.web.web_decorators import requires_show, no_live_session from utils.web.route import ApiRoute, ApiVersion +from utils.web.web_decorators import no_live_session, requires_show -@ApiRoute('show/script/config', ApiVersion.V1) +@ApiRoute("show/script/config", ApiVersion.V1) class ScriptStatusController(BaseAPIController): def get(self): with self.make_session() as session: - editors: List[Session] = session.query(Session).filter(Session.is_editor).all() + editors: List[Session] = ( + session.query(Session).filter(Session.is_editor).all() + ) if editors: current_editor = editors[0].internal_id else: current_editor = None data = { - 'canRequestEdit': len(editors) == 0, - 'currentEditor': current_editor + "canRequestEdit": len(editors) == 0, + "currentEditor": current_editor, } self.set_status(200) self.finish(data) -@ApiRoute('show/script/revisions', ApiVersion.V1) +@ApiRoute("show/script/revisions", ApiVersion.V1) class ScriptRevisionsController(BaseAPIController): @requires_show def get(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] revisions_schema = ScriptRevisionsSchema() with self.make_session() as session: show: Show = session.query(Show).get(show_id) if show: - script: Script = session.query(Script).filter(Script.show_id == show.id).first() + script: Script = ( + session.query(Script).filter(Script.show_id == show.id).first() + ) if script: revisions = [revisions_schema.dump(c) for c in script.revisions] self.set_status(200) - self.finish({ - 'current_revision': script.current_revision, - 'revisions': revisions - }) + self.finish( + { + "current_revision": script.current_revision, + "revisions": revisions, + } + ) else: self.set_status(404) - self.finish({'message': '404 script not found'}) + self.finish({"message": "404 script not found"}) else: self.set_status(404) - self.finish({'message': '404 show not found'}) + self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def post(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) if show: data = escape.json_decode(self.request.body) - script: Script = session.query(Script).filter(Script.show_id == show.id).first() + script: Script = ( + session.query(Script).filter(Script.show_id == show.id).first() + ) if not script: self.set_status(404) - await self.finish({'message': '404 script not found'}) + await self.finish({"message": "404 script not found"}) return self.requires_role(script, Role.WRITE) current_rev_id = script.current_revision if not current_rev_id: self.set_status(404) - await self.finish({'message': '404 script revision not found'}) + await self.finish({"message": "404 script revision not found"}) return - current_rev: ScriptRevision = session.query(ScriptRevision).get(current_rev_id) + current_rev: ScriptRevision = session.query(ScriptRevision).get( + current_rev_id + ) if not current_rev: self.set_status(404) - await self.finish({'message': '404 script revision not found'}) + await self.finish({"message": "404 script revision not found"}) return - max_rev = session.query(func.max(ScriptRevision.revision)).filter( - ScriptRevision.script_id == script.id).one()[0] + max_rev = ( + session.query(func.max(ScriptRevision.revision)) + .filter(ScriptRevision.script_id == script.id) + .one()[0] + ) - description: str = data.get('description', None) + description: str = data.get("description", None) if not description: self.set_status(400) - await self.finish({'message': 'Description missing'}) + await self.finish({"message": "Description missing"}) return now_time = datetime.utcnow() - new_rev = ScriptRevision(script_id=script.id, - revision=max_rev + 1, - created_at=now_time, - edited_at=now_time, - description=description, - previous_revision_id=current_rev.id) + new_rev = ScriptRevision( + script_id=script.id, + revision=max_rev + 1, + created_at=now_time, + edited_at=now_time, + description=description, + previous_revision_id=current_rev.id, + ) session.add(new_rev) session.flush() for line_association in current_rev.line_associations: - new_rev.line_associations.append(ScriptLineRevisionAssociation( - revision_id=new_rev.id, - line_id=line_association.line_id, - next_line_id=line_association.next_line_id, - previous_line_id=line_association.previous_line_id - )) + new_rev.line_associations.append( + ScriptLineRevisionAssociation( + revision_id=new_rev.id, + line_id=line_association.line_id, + next_line_id=line_association.next_line_id, + previous_line_id=line_association.previous_line_id, + ) + ) for cue_association in current_rev.cue_associations: - new_rev.cue_associations.append(CueAssociation( - revision_id=new_rev.id, - line_id=cue_association.line_id, - cue_id=cue_association.cue_id - )) + new_rev.cue_associations.append( + CueAssociation( + revision_id=new_rev.id, + line_id=cue_association.line_id, + cue_id=cue_association.cue_id, + ) + ) for cut_association in current_rev.line_part_cuts: - new_rev.line_part_cuts.append(ScriptCuts( - revision_id=new_rev.id, - line_part_id=cut_association.line_part_id, - )) + new_rev.line_part_cuts.append( + ScriptCuts( + revision_id=new_rev.id, + line_part_id=cut_association.line_part_id, + ) + ) script.current_revision = new_rev.id session.commit() self.set_status(200) - await self.finish({'id': new_rev.id, 'message': 'Successfully added script revision'}) - await self.application.ws_send_to_all('NOOP', 'GET_SCRIPT_REVISIONS', {}) + await self.finish( + {"id": new_rev.id, "message": "Successfully added script revision"} + ) + await self.application.ws_send_to_all( + "NOOP", "GET_SCRIPT_REVISIONS", {} + ) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def delete(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show: Show = session.query(Show).get(show_id) if show: data = escape.json_decode(self.request.body) - script: Script = session.query(Script).filter(Script.show_id == show.id).first() + script: Script = ( + session.query(Script).filter(Script.show_id == show.id).first() + ) if not script: self.set_status(404) - await self.finish({'message': '404 script not found'}) + await self.finish({"message": "404 script not found"}) return self.requires_role(script, Role.WRITE) - rev_id: int = data.get('rev_id', None) + rev_id: int = data.get("rev_id", None) if not rev_id: self.set_status(400) - await self.finish({'message': 'Revision missing'}) + await self.finish({"message": "Revision missing"}) return rev: ScriptRevision = session.query(ScriptRevision).get(rev_id) if not rev: self.set_status(404) - await self.finish({'message': 'Revision not found'}) + await self.finish({"message": "Revision not found"}) return if rev.script_id != script.id: self.set_status(400) - await self.finish({'message': 'Revision is not for the current script'}) + await self.finish( + {"message": "Revision is not for the current script"} + ) return if rev.revision == 1: self.set_status(400) - await self.finish({'message': 'Cannot delete first script revision'}) + await self.finish( + {"message": "Cannot delete first script revision"} + ) return changed_rev = False @@ -186,55 +228,68 @@ async def delete(self): if rev.previous_revision_id: script.current_revision = rev.previous_revision_id else: - first_rev: ScriptRevision = session.query(ScriptRevision).filter( - ScriptRevision.script_id == script.id, - ScriptRevision.revision == 1).one() + first_rev: ScriptRevision = ( + session.query(ScriptRevision) + .filter( + ScriptRevision.script_id == script.id, + ScriptRevision.revision == 1, + ) + .one() + ) script.current_revision = first_rev.id session.delete(rev) session.commit() self.set_status(200) - await self.finish({'message': 'Successfully deleted script revision'}) + await self.finish({"message": "Successfully deleted script revision"}) if changed_rev: - await self.application.ws_send_to_all('NOOP', 'SCRIPT_REVISION_CHANGED', {}) + await self.application.ws_send_to_all( + "NOOP", "SCRIPT_REVISION_CHANGED", {} + ) else: - await self.application.ws_send_to_all('NOOP', 'GET_SCRIPT_REVISIONS', {}) + await self.application.ws_send_to_all( + "NOOP", "GET_SCRIPT_REVISIONS", {} + ) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) -@ApiRoute('show/script/revisions/current', ApiVersion.V1) +@ApiRoute("show/script/revisions/current", ApiVersion.V1) class ScriptCurrentRevisionController(BaseAPIController): @requires_show def get(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show: Show = session.query(Show).get(show_id) if show: - script: Script = session.query(Script).filter(Script.show_id == show.id).first() + script: Script = ( + session.query(Script).filter(Script.show_id == show.id).first() + ) if script: self.set_status(200) - self.finish({ - 'current_revision': script.current_revision, - }) + self.finish( + { + "current_revision": script.current_revision, + } + ) else: self.set_status(404) - self.finish({'message': '404 script not found'}) + self.finish({"message": "404 script not found"}) else: self.set_status(404) - self.finish({'message': '404 show not found'}) + self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def post(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) @@ -242,52 +297,58 @@ async def post(self): self.requires_role(show, Role.WRITE) data = escape.json_decode(self.request.body) - script: Script = session.query(Script).filter(Script.show_id == show.id).first() + script: Script = ( + session.query(Script).filter(Script.show_id == show.id).first() + ) if not script: self.set_status(404) - await self.finish({'message': '404 script not found'}) + await self.finish({"message": "404 script not found"}) return - new_rev_id: int = data.get('new_rev_id', None) + new_rev_id: int = data.get("new_rev_id", None) if not new_rev_id: self.set_status(400) - await self.finish({'message': 'New revision missing'}) + await self.finish({"message": "New revision missing"}) return new_rev: ScriptRevision = session.query(ScriptRevision).get(new_rev_id) if not new_rev: self.set_status(404) - await self.finish({'message': 'New revision not found'}) + await self.finish({"message": "New revision not found"}) return if new_rev.script_id != script.id: self.set_status(400) - await self.finish({'message': 'New revision is not for the current script'}) + await self.finish( + {"message": "New revision is not for the current script"} + ) return script.current_revision = new_rev.id session.commit() self.set_status(200) - await self.finish({'message': 'Successfully changed script revision'}) - await self.application.ws_send_to_all('NOOP', 'SCRIPT_REVISION_CHANGED', {}) + await self.finish({"message": "Successfully changed script revision"}) + await self.application.ws_send_to_all( + "NOOP", "SCRIPT_REVISION_CHANGED", {} + ) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) -@ApiRoute('/show/script', ApiVersion.V1) +@ApiRoute("/show/script", ApiVersion.V1) class ScriptController(BaseAPIController): @requires_show def get(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] - page = self.get_query_argument('page', None) + page = self.get_query_argument("page", None) if not page: self.set_status(400) - self.finish({'message': 'Page not given'}) + self.finish({"message": "Page not given"}) return page = int(page) @@ -297,28 +358,40 @@ def get(self): with self.make_session() as session: show = session.query(Show).get(show_id) if show: - script: Script = session.query(Script).filter(Script.show_id == show.id).first() + script: Script = ( + session.query(Script).filter(Script.show_id == show.id).first() + ) if script.current_revision: revision: ScriptRevision = session.query(ScriptRevision).get( - script.current_revision) + script.current_revision + ) else: self.set_status(400) - self.finish({'message': 'Script does not have a current revision'}) + self.finish({"message": "Script does not have a current revision"}) return - revision_lines: List[ScriptLineRevisionAssociation] = session.query( - ScriptLineRevisionAssociation).filter( - ScriptLineRevisionAssociation.revision_id == revision.id, - ScriptLineRevisionAssociation.line.has(page=page)).all() + revision_lines: List[ScriptLineRevisionAssociation] = ( + session.query(ScriptLineRevisionAssociation) + .filter( + ScriptLineRevisionAssociation.revision_id == revision.id, + ScriptLineRevisionAssociation.line.has(page=page), + ) + .all() + ) first_line = None for line in revision_lines: - if (page == 1 and line.previous_line is None - or line.previous_line.page == page - 1): + if ( + page == 1 + and line.previous_line is None + or line.previous_line.page == page - 1 + ): if first_line: self.set_status(400) - self.finish({'message': 'Failed to establish page line order'}) + self.finish( + {"message": "Failed to establish page line order"} + ) return first_line = line @@ -331,52 +404,69 @@ def get(self): lines.append(line_schema.dump(line_revision.line)) line_revision = session.query(ScriptLineRevisionAssociation).get( - {'revision_id': revision.id, 'line_id': line_revision.next_line_id}) + { + "revision_id": revision.id, + "line_id": line_revision.next_line_id, + } + ) self.set_status(200) - self.finish({'lines': lines, 'page': page}) + self.finish({"lines": lines, "page": page}) else: self.set_status(404) - self.finish({'message': '404 show not found'}) + self.finish({"message": "404 show not found"}) return @staticmethod def _validate_line(line_json): - if line_json['stage_direction']: - if len(line_json['line_parts']) > 1: - return False, 'Stage directions can only have 1 line part' - line_part = line_json['line_parts'][0] - if line_part['character_id'] is not None: - return False, 'Stage directions cannot have characters' - if line_part['character_group_id'] is not None: - return False, 'Stage directions cannot have character groups' - if line_part['line_text'] is None: - return False, 'Stage directions must contain text' + if line_json["stage_direction"]: + if len(line_json["line_parts"]) > 1: + return False, "Stage directions can only have 1 line part" + line_part = line_json["line_parts"][0] + if line_part["character_id"] is not None: + return False, "Stage directions cannot have characters" + if line_part["character_group_id"] is not None: + return False, "Stage directions cannot have character groups" + if line_part["line_text"] is None: + return False, "Stage directions must contain text" else: - for line_part in line_json['line_parts']: - if line_part['line_text'] is None: - if len(line_json['line_parts']) == 1: - return False, 'Line parts must contain text' - if not any(lp['line_text'] is not None for lp in line_json['line_parts']): - return False, ('At least one line part in a multi part line must ' - 'contain text') - if line_part['character_id'] is None and line_part['character_group_id'] is None: - return False, 'Line parts must contain a character or character group' - if line_part['character_id'] and line_part['character_group_id']: - return False, 'Line parts cannot contain both a character and character group' - - return True, '' + for line_part in line_json["line_parts"]: + if line_part["line_text"] is None: + if len(line_json["line_parts"]) == 1: + return False, "Line parts must contain text" + if not any( + lp["line_text"] is not None for lp in line_json["line_parts"] + ): + return False, ( + "At least one line part in a multi part line must " + "contain text" + ) + if ( + line_part["character_id"] is None + and line_part["character_group_id"] is None + ): + return ( + False, + "Line parts must contain a character or character group", + ) + if line_part["character_id"] and line_part["character_group_id"]: + return ( + False, + "Line parts cannot contain both a character and character group", + ) + + return True, "" @requires_show @no_live_session async def post(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] - page = self.get_query_argument('page', None) + page = self.get_query_argument("page", None) if not page: self.set_status(400) - await self.finish({'message': 'Page not given'}) + await self.finish({"message": "Page not given"}) return page = int(page) @@ -384,15 +474,20 @@ async def post(self): with self.make_session() as session: show = session.query(Show).get(show_id) if show: - script: Script = session.query(Script).filter(Script.show_id == show.id).first() + script: Script = ( + session.query(Script).filter(Script.show_id == show.id).first() + ) self.requires_role(script, Role.WRITE) if script.current_revision: revision: ScriptRevision = session.query(ScriptRevision).get( - script.current_revision) + script.current_revision + ) else: self.set_status(400) - await self.finish({'message': 'Script does not have a current revision'}) + await self.finish( + {"message": "Script does not have a current revision"} + ) return lines = escape.json_decode(self.request.body) @@ -405,53 +500,66 @@ async def post(self): if not valid_status: session.rollback() self.set_status(400) - await self.finish({'message': valid_reason}) + await self.finish({"message": valid_reason}) return # Create the initial line object, and flush it to the database as we need # the ID for further in the loop - line_obj = ScriptLine(act_id=line['act_id'], - scene_id=line['scene_id'], - page=line['page'], - stage_direction=line['stage_direction']) + line_obj = ScriptLine( + act_id=line["act_id"], + scene_id=line["scene_id"], + page=line["page"], + stage_direction=line["stage_direction"], + ) session.add(line_obj) session.flush() # Line revision object to keep track of that thing - line_revision = ScriptLineRevisionAssociation(revision_id=revision.id, - line_id=line_obj.id) + line_revision = ScriptLineRevisionAssociation( + revision_id=revision.id, line_id=line_obj.id + ) session.add(line_revision) session.flush() if index == 0 and page > 1: # First line and not the first page, so need to get the last line of the # previous page and set its next line to this one - prev_page_lines: List[ScriptLineRevisionAssociation] = session.query( - ScriptLineRevisionAssociation).filter( - ScriptLineRevisionAssociation.revision_id == revision.id, - ScriptLineRevisionAssociation.line.has(page=page - 1)).all() + prev_page_lines: List[ScriptLineRevisionAssociation] = ( + session.query(ScriptLineRevisionAssociation) + .filter( + ScriptLineRevisionAssociation.revision_id + == revision.id, + ScriptLineRevisionAssociation.line.has(page=page - 1), + ) + .all() + ) if not prev_page_lines: session.rollback() self.set_status(400) - await self.finish({ - 'message': 'Previous page does not contain any lines' - }) + await self.finish( + {"message": "Previous page does not contain any lines"} + ) return # Perform some iteration here to establish the first line of the script # of the previous page first_line = None for prev_line in prev_page_lines: - if (prev_line.previous_line is None or - prev_line.previous_line.page == prev_line.line.page - 1): + if ( + prev_line.previous_line is None + or prev_line.previous_line.page + == prev_line.line.page - 1 + ): if first_line: session.rollback() self.set_status(400) - await self.finish({ - 'message': 'Failed to establish page line order for ' - 'previous page' - }) + await self.finish( + { + "message": "Failed to establish page line order for " + "previous page" + } + ) return first_line = prev_line @@ -465,10 +573,14 @@ async def post(self): break previous_lines.append(prev_line) - prev_line = session.query(ScriptLineRevisionAssociation).get({ - 'revision_id': revision.id, - 'line_id': prev_line.next_line_id - }) + prev_line = session.query( + ScriptLineRevisionAssociation + ).get( + { + "revision_id": revision.id, + "line_id": prev_line.next_line_id, + } + ) previous_lines[-1].next_line_id = line_obj.id line_revision.previous_line_id = previous_lines[-1].line_id @@ -481,13 +593,14 @@ async def post(self): session.flush() # Construct the line part objects and add these to the line itself - for line_part in line['line_parts']: - part_obj = ScriptLinePart(line_id=line_obj.id, - part_index=line_part['part_index'], - character_id=line_part['character_id'], - character_group_id=line_part[ - 'character_group_id'], - line_text=line_part['line_text']) + for line_part in line["line_parts"]: + part_obj = ScriptLinePart( + line_id=line_obj.id, + part_index=line_part["part_index"], + character_id=line_part["character_id"], + character_group_id=line_part["character_group_id"], + line_text=line_part["line_text"], + ) session.add(part_obj) line_obj.line_parts.append(part_obj) @@ -505,24 +618,27 @@ async def post(self): session.commit() else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) return @staticmethod def _create_new_line(session, revision, line, previous_line, with_association=True): # Create the line object - line_obj = ScriptLine(act_id=line['act_id'], - scene_id=line['scene_id'], - page=line['page'], - stage_direction=line['stage_direction']) + line_obj = ScriptLine( + act_id=line["act_id"], + scene_id=line["scene_id"], + page=line["page"], + stage_direction=line["stage_direction"], + ) session.add(line_obj) session.flush() line_association = None if with_association: # Line revision object to keep track of that thing - line_association = ScriptLineRevisionAssociation(revision_id=revision.id, - line_id=line_obj.id) + line_association = ScriptLineRevisionAssociation( + revision_id=revision.id, line_id=line_obj.id + ) session.add(line_association) session.flush() @@ -533,12 +649,14 @@ def _create_new_line(session, revision, line, previous_line, with_association=Tr session.flush() # Construct the line part objects and add these to the line itself - for line_part in line['line_parts']: - part_obj = ScriptLinePart(line_id=line_obj.id, - part_index=line_part['part_index'], - character_id=line_part['character_id'], - character_group_id=line_part['character_group_id'], - line_text=line_part['line_text']) + for line_part in line["line_parts"]: + part_obj = ScriptLinePart( + line_id=line_obj.id, + part_index=line_part["part_index"], + character_id=line_part["character_id"], + character_group_id=line_part["character_group_id"], + line_text=line_part["line_text"], + ) session.add(part_obj) line_obj.line_parts.append(part_obj) @@ -551,12 +669,12 @@ def _create_new_line(session, revision, line, previous_line, with_association=Tr @no_live_session async def patch(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] - page = self.get_query_argument('page', None) + page = self.get_query_argument("page", None) if not page: self.set_status(400) - await self.finish({'message': 'Page not given'}) + await self.finish({"message": "Page not given"}) return page = int(page) @@ -564,117 +682,151 @@ async def patch(self): with self.make_session() as session: show = session.query(Show).get(show_id) if show: - script: Script = session.query(Script).filter(Script.show_id == show.id).first() + script: Script = ( + session.query(Script).filter(Script.show_id == show.id).first() + ) self.requires_role(script, Role.WRITE) if script.current_revision: revision: ScriptRevision = session.query(ScriptRevision).get( - script.current_revision) + script.current_revision + ) else: self.set_status(400) - await self.finish({'message': 'Script does not have a current revision'}) + await self.finish( + {"message": "Script does not have a current revision"} + ) return request_body = escape.json_decode(self.request.body) - lines = request_body.get('page', None) + lines = request_body.get("page", None) if lines is None: self.set_status(400) - await self.finish({ - 'message': 'Malformed request body, could not find `page` data' - }) + await self.finish( + { + "message": "Malformed request body, could not find `page` data" + } + ) return - status = request_body.get('status', None) + status = request_body.get("status", None) if status is None: self.set_status(400) - await self.finish({ - 'message': 'Malformed request body, could not find `status` data' - }) + await self.finish( + { + "message": "Malformed request body, could not find `status` data" + } + ) return # If we are editing a page other than the first page, we need to get the previous # line based on the last line from the previous page to ensure that any edits made # to the first line of this page are reflected properly if page > 1: - if lines[0]['id'] is not None: + if lines[0]["id"] is not None: first_line = session.query(ScriptLineRevisionAssociation).get( - {'revision_id': revision.id, - 'line_id': lines[0]['id']}) + {"revision_id": revision.id, "line_id": lines[0]["id"]} + ) if not first_line: session.rollback() self.set_status(400) - await self.finish({'message': 'Unable to load line data for first line'}) + await self.finish( + {"message": "Unable to load line data for first line"} + ) return if not first_line.previous_line: session.rollback() self.set_status(400) - await self.finish({ - 'message': 'Unable to establish page line order - ' - 'first line on this page does not have a previous line' - }) + await self.finish( + { + "message": "Unable to establish page line order - " + "first line on this page does not have a previous line" + } + ) return - previous_line = session.query(ScriptLineRevisionAssociation).get( - {'revision_id': revision.id, - 'line_id': first_line.previous_line.id}) + previous_line = session.query( + ScriptLineRevisionAssociation + ).get( + { + "revision_id": revision.id, + "line_id": first_line.previous_line.id, + } + ) if not previous_line: session.rollback() self.set_status(400) - await self.finish({ - 'message': 'Unable to establish page line order - ' - 'could not find previous line data for first line on this ' - 'page' - }) + await self.finish( + { + "message": "Unable to establish page line order - " + "could not find previous line data for first line on this " + "page" + } + ) return else: session.rollback() self.set_status(400) - await self.finish({'message': 'Cannot establish line order as first line has ' - 'no ID'}) + await self.finish( + { + "message": "Cannot establish line order as first line has " + "no ID" + } + ) return else: previous_line: Optional[ScriptLineRevisionAssociation] = None for index, line in enumerate(lines): - if index in status['added']: + if index in status["added"]: # Validate the line valid_status, valid_reason = self._validate_line(line) if not valid_status: session.rollback() self.set_status(400) - await self.finish({'message': valid_reason}) + await self.finish({"message": valid_reason}) return line_association, line_object = self._create_new_line( - session, revision, line, previous_line) + session, revision, line, previous_line + ) previous_line = line_association - elif index in status['inserted']: + elif index in status["inserted"]: # Validate the line valid_status, valid_reason = self._validate_line(line) if not valid_status: session.rollback() self.set_status(400) - await self.finish({'message': valid_reason}) + await self.finish({"message": valid_reason}) return line_association, line_object = self._create_new_line( - session, revision, line, previous_line, with_association=False) + session, + revision, + line, + previous_line, + with_association=False, + ) line_association = ScriptLineRevisionAssociation( - revision_id=revision.id, - line_id=line_object.id) + revision_id=revision.id, line_id=line_object.id + ) line_association.previous_line = previous_line.line session.add(line_association) session.flush() if previous_line.next_line: - next_association: ScriptLineRevisionAssociation = session.query( - ScriptLineRevisionAssociation).get( - {'revision_id': revision.id, - 'line_id': previous_line.next_line.id}) + next_association: ( + ScriptLineRevisionAssociation + ) = session.query(ScriptLineRevisionAssociation).get( + { + "revision_id": revision.id, + "line_id": previous_line.next_line.id, + } + ) next_association.previous_line = line_object line_association.next_line = next_association.line @@ -683,76 +835,106 @@ async def patch(self): session.flush() previous_line = line_association - elif index in status['deleted']: + elif index in status["deleted"]: curr_association: ScriptLineRevisionAssociation = session.query( - ScriptLineRevisionAssociation).get( - {'revision_id': revision.id, 'line_id': line['id']}) + ScriptLineRevisionAssociation + ).get({"revision_id": revision.id, "line_id": line["id"]}) # Logic for handling next/previous line associations - if curr_association.next_line and curr_association.previous_line: + if ( + curr_association.next_line + and curr_association.previous_line + ): # Next line and previous line, so need to update both - next_association: ScriptLineRevisionAssociation = session.query( - ScriptLineRevisionAssociation).get( - {'revision_id': revision.id, - 'line_id': curr_association.next_line.id}) - next_association.previous_line = curr_association.previous_line + next_association: ( + ScriptLineRevisionAssociation + ) = session.query(ScriptLineRevisionAssociation).get( + { + "revision_id": revision.id, + "line_id": curr_association.next_line.id, + } + ) + next_association.previous_line = ( + curr_association.previous_line + ) session.flush() - prev_association: ScriptLineRevisionAssociation = session.query( - ScriptLineRevisionAssociation).get( - {'revision_id': revision.id, - 'line_id': curr_association.previous_line.id}) + prev_association: ( + ScriptLineRevisionAssociation + ) = session.query(ScriptLineRevisionAssociation).get( + { + "revision_id": revision.id, + "line_id": curr_association.previous_line.id, + } + ) prev_association.next_line = next_association.line session.flush() elif curr_association.next_line: # No previous line, so need to update next line only - next_association: ScriptLineRevisionAssociation = session.query( - ScriptLineRevisionAssociation).get( - {'revision_id': revision.id, - 'line_id': curr_association.next_line.id}) + next_association: ( + ScriptLineRevisionAssociation + ) = session.query(ScriptLineRevisionAssociation).get( + { + "revision_id": revision.id, + "line_id": curr_association.next_line.id, + } + ) next_association.previous_line = None session.flush() elif curr_association.previous_line: # No next line, so need to update previous line only - prev_association: ScriptLineRevisionAssociation = session.query( - ScriptLineRevisionAssociation).get( - {'revision_id': revision.id, - 'line_id': curr_association.previous_line.id}) + prev_association: ( + ScriptLineRevisionAssociation + ) = session.query(ScriptLineRevisionAssociation).get( + { + "revision_id": revision.id, + "line_id": curr_association.previous_line.id, + } + ) prev_association.next_line = None session.flush() session.delete(curr_association) - elif index in status['updated']: + elif index in status["updated"]: # Validate the line valid_status, valid_reason = self._validate_line(line) if not valid_status: session.rollback() self.set_status(400) - await self.finish({'message': valid_reason}) + await self.finish({"message": valid_reason}) return curr_association: ScriptLineRevisionAssociation = session.query( - ScriptLineRevisionAssociation).get( - {'revision_id': revision.id, 'line_id': line['id']}) + ScriptLineRevisionAssociation + ).get({"revision_id": revision.id, "line_id": line["id"]}) curr_line = curr_association.line if not curr_association: session.rollback() self.set_status(500) - await self.finish({'message': 'Unable to load line data'}) + await self.finish({"message": "Unable to load line data"}) return line_association, line_object = self._create_new_line( - session, revision, line, previous_line, with_association=False) + session, + revision, + line, + previous_line, + with_association=False, + ) curr_association.line = line_object if previous_line: previous_line.next_line = line_object curr_association.previous_line = previous_line.line if curr_association.next_line: - next_association: ScriptLineRevisionAssociation = session.query( - ScriptLineRevisionAssociation).get( - {'revision_id': revision.id, - 'line_id': curr_association.next_line.id}) + next_association: ( + ScriptLineRevisionAssociation + ) = session.query(ScriptLineRevisionAssociation).get( + { + "revision_id": revision.id, + "line_id": curr_association.next_line.id, + } + ) next_association.previous_line = line_object session.flush() @@ -763,81 +945,94 @@ async def patch(self): previous_line = curr_association else: - previous_line = session.query(ScriptLineRevisionAssociation).get( - {'revision_id': revision.id, 'line_id': line['id']}) + previous_line = session.query( + ScriptLineRevisionAssociation + ).get({"revision_id": revision.id, "line_id": line["id"]}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) return -@ApiRoute('/show/script/cuts', ApiVersion.V1) +@ApiRoute("/show/script/cuts", ApiVersion.V1) class ScriptCutsController(BaseAPIController): @requires_show def get(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) if show: - script: Script = session.query(Script).filter(Script.show_id == show.id).first() + script: Script = ( + session.query(Script).filter(Script.show_id == show.id).first() + ) if script.current_revision: revision: ScriptRevision = session.query(ScriptRevision).get( - script.current_revision) + script.current_revision + ) else: self.set_status(400) - self.finish({'message': 'Script does not have a current revision'}) + self.finish({"message": "Script does not have a current revision"}) return - line_cuts = session.query(ScriptCuts).filter( - ScriptCuts.revision_id == revision.id).all() + line_cuts = ( + session.query(ScriptCuts) + .filter(ScriptCuts.revision_id == revision.id) + .all() + ) line_cuts = [line_cut.line_part_id for line_cut in line_cuts] self.set_status(200) - self.finish({ - 'cuts': line_cuts - }) + self.finish({"cuts": line_cuts}) else: self.set_status(404) - self.finish({'message': '404 show not found'}) + self.finish({"message": "404 show not found"}) return @requires_show @no_live_session def put(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) if show: - script: Script = session.query(Script).filter(Script.show_id == show.id).first() + script: Script = ( + session.query(Script).filter(Script.show_id == show.id).first() + ) self.requires_role(script, Role.WRITE) if script.current_revision: revision: ScriptRevision = session.query(ScriptRevision).get( - script.current_revision) + script.current_revision + ) else: self.set_status(400) - self.finish({'message': 'Script does not have a current revision'}) + self.finish({"message": "Script does not have a current revision"}) return request_body = escape.json_decode(self.request.body) - cuts = request_body.get('cuts', None) + cuts = request_body.get("cuts", None) if cuts is None: self.set_status(400) - self.finish({ - 'message': 'Malformed request body, could not find `cuts` data' - }) + self.finish( + { + "message": "Malformed request body, could not find `cuts` data" + } + ) return - line_cuts: List[ScriptCuts] = session.query(ScriptCuts).filter( - ScriptCuts.revision_id == revision.id).all() + line_cuts: List[ScriptCuts] = ( + session.query(ScriptCuts) + .filter(ScriptCuts.revision_id == revision.id) + .all() + ) # Remove any cuts not in the list existing_cuts = [] @@ -850,115 +1045,133 @@ def put(self): # Add new cuts cuts_to_add = [cut for cut in cuts if cut not in existing_cuts] for cut in cuts_to_add: - session.add(ScriptCuts( - line_part_id=cut, - revision_id=revision.id, - )) + session.add( + ScriptCuts( + line_part_id=cut, + revision_id=revision.id, + ) + ) session.commit() -@ApiRoute('/show/script/max_page', ApiVersion.V1) +@ApiRoute("/show/script/max_page", ApiVersion.V1) class ScriptMaxPageController(BaseAPIController): @requires_show def get(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) if show: - script: Script = session.query(Script).filter(Script.show_id == show.id).first() + script: Script = ( + session.query(Script).filter(Script.show_id == show.id).first() + ) if script.current_revision: revision: ScriptRevision = session.query(ScriptRevision).get( - script.current_revision) + script.current_revision + ) else: self.set_status(400) - self.finish({'message': 'Script does not have a current revision'}) + self.finish({"message": "Script does not have a current revision"}) return - line_ids = session.query(ScriptLineRevisionAssociation).with_entities( - ScriptLineRevisionAssociation.line_id).filter( - ScriptLineRevisionAssociation.revision_id == revision.id) - max_page = session.query(ScriptLine).with_entities( - func.max(ScriptLine.page)).where( - ScriptLine.id.in_(line_ids)).first()[0] + line_ids = ( + session.query(ScriptLineRevisionAssociation) + .with_entities(ScriptLineRevisionAssociation.line_id) + .filter(ScriptLineRevisionAssociation.revision_id == revision.id) + ) + max_page = ( + session.query(ScriptLine) + .with_entities(func.max(ScriptLine.page)) + .where(ScriptLine.id.in_(line_ids)) + .first()[0] + ) if max_page is None: max_page = 0 self.set_status(200) - self.finish({ - 'max_page': max_page - }) + self.finish({"max_page": max_page}) else: self.set_status(404) - self.finish({'message': '404 show not found'}) + self.finish({"message": "404 show not found"}) return -@ApiRoute('/show/script/stage_direction_styles', ApiVersion.V1) + +@ApiRoute("/show/script/stage_direction_styles", ApiVersion.V1) class StageDirectionStylesController(BaseAPIController): @requires_show def get(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] stage_direction_style_schema = StageDirectionStyleSchema() with self.make_session() as session: show = session.query(Show).get(show_id) if show: - script: Script = session.query(Script).filter(Script.show_id == show.id).first() - stage_direction_styles = [stage_direction_style_schema.dump(style) for style in script.stage_direction_styles] + script: Script = ( + session.query(Script).filter(Script.show_id == show.id).first() + ) + stage_direction_styles = [ + stage_direction_style_schema.dump(style) + for style in script.stage_direction_styles + ] self.set_status(200) - self.finish({'styles': stage_direction_styles}) + self.finish({"styles": stage_direction_styles}) else: self.set_status(404) - self.finish({'message': '404 show not found'}) + self.finish({"message": "404 show not found"}) return @requires_show @no_live_session async def post(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) if show: - script: Script = session.query(Script).filter(Script.show_id == show.id).first() + script: Script = ( + session.query(Script).filter(Script.show_id == show.id).first() + ) self.requires_role(script, Role.WRITE) data = escape.json_decode(self.request.body) - description: str = data.get('description', None) + description: str = data.get("description", None) if not description: self.set_status(400) - await self.finish({'message': 'Description missing'}) + await self.finish({"message": "Description missing"}) return - bold: bool = data.get('bold', False) - italic: bool = data.get('italic', False) - underline: bool = data.get('underline', False) + bold: bool = data.get("bold", False) + italic: bool = data.get("italic", False) + underline: bool = data.get("underline", False) - text_format: str = data.get('textFormat', None) - if not text_format or text_format not in ['default', 'upper', 'lower']: + text_format: str = data.get("textFormat", None) + if not text_format or text_format not in ["default", "upper", "lower"]: self.set_status(400) - await self.finish({'message': 'Text format missing or invalid'}) + await self.finish({"message": "Text format missing or invalid"}) return - text_colour: str = data.get('textColour', None) + text_colour: str = data.get("textColour", None) if not text_colour: self.set_status(400) - await self.finish({'message': 'Text colour missing'}) + await self.finish({"message": "Text colour missing"}) return - enable_background_colour: bool = data.get('enableBackgroundColour', False) - background_colour: str = data.get('backgroundColour', None) + enable_background_colour: bool = data.get( + "enableBackgroundColour", False + ) + background_colour: str = data.get("backgroundColour", None) if enable_background_colour and not background_colour: self.set_status(400) - await self.finish({'message': 'Background colour missing'}) + await self.finish({"message": "Background colour missing"}) return new_style = StageDirectionStyle( @@ -976,65 +1189,80 @@ async def post(self): session.commit() self.set_status(200) - await self.finish({'id': new_style.id, 'message': 'Successfully added stage direction style'}) + await self.finish( + { + "id": new_style.id, + "message": "Successfully added stage direction style", + } + ) - await self.application.ws_send_to_all('NOOP', 'GET_STAGE_DIRECTION_STYLES', {}) + await self.application.ws_send_to_all( + "NOOP", "GET_STAGE_DIRECTION_STYLES", {} + ) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def patch(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show: Show = session.query(Show).get(show_id) if show: - script: Script = session.query(Script).filter(Script.show_id == show.id).first() + script: Script = ( + session.query(Script).filter(Script.show_id == show.id).first() + ) self.requires_role(script, Role.WRITE) data = escape.json_decode(self.request.body) - style_id = data.get('id', None) + style_id = data.get("id", None) if not style_id: self.set_status(400) - await self.finish({'message': 'ID missing'}) + await self.finish({"message": "ID missing"}) return - style: StageDirectionStyle = session.query(StageDirectionStyle).get(style_id) + style: StageDirectionStyle = session.query(StageDirectionStyle).get( + style_id + ) if not style: self.set_status(404) - await self.finish({'message': '404 stage direction style not found'}) + await self.finish( + {"message": "404 stage direction style not found"} + ) return - description: str = data.get('description', None) + description: str = data.get("description", None) if not description: self.set_status(400) - await self.finish({'message': 'Description missing'}) + await self.finish({"message": "Description missing"}) return - bold: bool = data.get('bold', False) - italic: bool = data.get('italic', False) - underline: bool = data.get('underline', False) + bold: bool = data.get("bold", False) + italic: bool = data.get("italic", False) + underline: bool = data.get("underline", False) - text_format: str = data.get('textFormat', None) - if not text_format or text_format not in ['default', 'upper', 'lower']: + text_format: str = data.get("textFormat", None) + if not text_format or text_format not in ["default", "upper", "lower"]: self.set_status(400) - await self.finish({'message': 'Text format missing or invalid'}) + await self.finish({"message": "Text format missing or invalid"}) return - text_colour: str = data.get('textColour', None) + text_colour: str = data.get("textColour", None) if not text_colour: self.set_status(400) - await self.finish({'message': 'Text colour missing'}) + await self.finish({"message": "Text colour missing"}) return - enable_background_colour: bool = data.get('enableBackgroundColour', False) - background_colour: str = data.get('backgroundColour', None) + enable_background_colour: bool = data.get( + "enableBackgroundColour", False + ) + background_colour: str = data.get("backgroundColour", None) if enable_background_colour and not background_colour: self.set_status(400) - await self.finish({'message': 'Background colour missing'}) + await self.finish({"message": "Background colour missing"}) return style.description = description @@ -1048,30 +1276,36 @@ async def patch(self): session.commit() self.set_status(200) - await self.finish({'message': 'Successfully edited stage direction style'}) + await self.finish( + {"message": "Successfully edited stage direction style"} + ) - await self.application.ws_send_to_all('NOOP', 'GET_STAGE_DIRECTION_STYLES', {}) + await self.application.ws_send_to_all( + "NOOP", "GET_STAGE_DIRECTION_STYLES", {} + ) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) @requires_show @no_live_session async def delete(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show: Show = session.query(Show).get(show_id) if show: - script: Script = session.query(Script).filter(Script.show_id == show.id).first() + script: Script = ( + session.query(Script).filter(Script.show_id == show.id).first() + ) self.requires_role(script, Role.WRITE) data = escape.json_decode(self.request.body) - style_id = data.get('id', None) + style_id = data.get("id", None) if not style_id: self.set_status(400) - await self.finish({'message': 'ID missing'}) + await self.finish({"message": "ID missing"}) return entry: StageDirectionStyle = session.get(StageDirectionStyle, style_id) @@ -1080,12 +1314,18 @@ async def delete(self): session.commit() self.set_status(200) - await self.finish({'message': 'Successfully deleted stage direction style'}) + await self.finish( + {"message": "Successfully deleted stage direction style"} + ) - await self.application.ws_send_to_all('NOOP', 'GET_STAGE_DIRECTION_STYLES', {}) + await self.application.ws_send_to_all( + "NOOP", "GET_STAGE_DIRECTION_STYLES", {} + ) else: self.set_status(404) - await self.finish({'message': '404 stage direction style not found'}) + await self.finish( + {"message": "404 stage direction style not found"} + ) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) diff --git a/server/controllers/api/show/sessions.py b/server/controllers/api/show/sessions.py index 77963adc..5eeb4f5a 100644 --- a/server/controllers/api/show/sessions.py +++ b/server/controllers/api/show/sessions.py @@ -2,49 +2,54 @@ from tornado import escape -from models.session import ShowSession, Session +from models.session import Session, ShowSession from models.show import Show from rbac.role import Role from schemas.schemas import ShowSessionSchema from utils.web.base_controller import BaseAPIController -from utils.web.web_decorators import requires_show from utils.web.route import ApiRoute, ApiVersion +from utils.web.web_decorators import requires_show -@ApiRoute('show/sessions', ApiVersion.V1) +@ApiRoute("show/sessions", ApiVersion.V1) class SessionsController(BaseAPIController): @requires_show def get(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] session_schema = ShowSessionSchema() with self.make_session() as session: show = session.query(Show).get(show_id) if show: - sessions = session.query(ShowSession).filter( - ShowSession.show_id == show.id).all() + sessions = ( + session.query(ShowSession) + .filter(ShowSession.show_id == show.id) + .all() + ) sessions = [session_schema.dump(s) for s in sessions] current_session = None if show.current_session_id: - current_session = session.query(ShowSession).get(show.current_session_id) + current_session = session.query(ShowSession).get( + show.current_session_id + ) current_session = session_schema.dump(current_session) self.set_status(200) - self.finish({'sessions': sessions, 'current_session': current_session}) + self.finish({"sessions": sessions, "current_session": current_session}) else: self.set_status(404) - self.finish({'message': '404 show not found'}) + self.finish({"message": "404 show not found"}) -@ApiRoute('show/sessions/start', ApiVersion.V1) +@ApiRoute("show/sessions/start", ApiVersion.V1) class SessionStartController(BaseAPIController): @requires_show async def post(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) @@ -52,20 +57,22 @@ async def post(self): self.requires_role(show, Role.EXECUTE) if show.current_session_id: self.set_status(409) - await self.finish({'message': '409 session already active'}) + await self.finish({"message": "409 session already active"}) else: data = escape.json_decode(self.request.body) - session_id = data.get('session_id', None) + session_id = data.get("session_id", None) if not session_id: self.set_status(400) - await self.finish({'message': 'session_id missing'}) + await self.finish({"message": "session_id missing"}) return user_session: Session = session.query(Session).get(session_id) if not user_session: self.set_status(400) - await self.finish({'message': 'Unable to find session given session_id'}) + await self.finish( + {"message": "Unable to find session given session_id"} + ) return show_session = ShowSession( @@ -73,7 +80,7 @@ async def post(self): start_date_time=datetime.utcnow(), end_date_time=None, client_internal_id=user_session.internal_id, - user_id=user_session.user.id + user_id=user_session.user.id, ) session.add(show_session) session.flush() @@ -82,21 +89,23 @@ async def post(self): session.commit() self.set_status(200) - self.write({'message': 'Successfully started show session'}) + self.write({"message": "Successfully started show session"}) - await self.application.ws_send_to_all('NOOP', 'GET_SHOW_SESSION_DATA', {}) - await self.application.ws_send_to_all('START_SHOW', 'NOOP', {}) + await self.application.ws_send_to_all( + "NOOP", "GET_SHOW_SESSION_DATA", {} + ) + await self.application.ws_send_to_all("START_SHOW", "NOOP", {}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) -@ApiRoute('show/sessions/stop', ApiVersion.V1) +@ApiRoute("show/sessions/stop", ApiVersion.V1) class SessionStopController(BaseAPIController): @requires_show async def post(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show = session.query(Show).get(show_id) @@ -104,19 +113,22 @@ async def post(self): self.requires_role(show, Role.EXECUTE) if not show.current_session_id: self.set_status(409) - await self.finish({'message': '409 no active session'}) + await self.finish({"message": "409 no active session"}) else: show_session: ShowSession = session.query(ShowSession).get( - show.current_session_id) + show.current_session_id + ) show_session.end_date_time = datetime.utcnow() show.current_session_id = None session.commit() self.set_status(200) - self.write({'message': 'Successfully stopped show session'}) + self.write({"message": "Successfully stopped show session"}) - await self.application.ws_send_to_all('NOOP', 'GET_SHOW_SESSION_DATA', {}) - await self.application.ws_send_to_all('STOP_SHOW', 'NOOP', {}) + await self.application.ws_send_to_all( + "NOOP", "GET_SHOW_SESSION_DATA", {} + ) + await self.application.ws_send_to_all("STOP_SHOW", "NOOP", {}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) diff --git a/server/controllers/api/show/shows.py b/server/controllers/api/show/shows.py index 9ca6b27d..4dce620d 100644 --- a/server/controllers/api/show/shows.py +++ b/server/controllers/api/show/shows.py @@ -1,18 +1,19 @@ from datetime import datetime + from dateutil import parser from tornado import escape, web +from digi_server.logger import get_logger from models.script import Script, ScriptRevision from models.show import Show from rbac.role import Role from schemas.schemas import ShowSchema from utils.web.base_controller import BaseAPIController -from utils.web.web_decorators import requires_show, require_admin from utils.web.route import ApiRoute, ApiVersion -from digi_server.logger import get_logger +from utils.web.web_decorators import require_admin, requires_show -@ApiRoute('show', ApiVersion.V1) +@ApiRoute("show", ApiVersion.V1) class ShowController(BaseAPIController): @web.authenticated @@ -22,20 +23,20 @@ async def post(self): Create a new show """ data = escape.json_decode(self.request.body) - get_logger().debug(f'New show data posted: {data}') + get_logger().debug(f"New show data posted: {data}") # Name - show_name = data.get('name', None) + show_name = data.get("name", None) if not show_name: self.set_status(400) - self.write({'message': 'Show name missing'}) + self.write({"message": "Show name missing"}) return # Start date - start_date = data.get('start', None) + start_date = data.get("start", None) if not start_date: self.set_status(400) - self.write({'message': 'Start date missing'}) + self.write({"message": "Start date missing"}) return try: start_date = parser.parse(start_date) @@ -43,14 +44,14 @@ async def post(self): raise Exception except BaseException: self.set_status(400) - self.write({'message': 'Unable to parse start date value'}) + self.write({"message": "Unable to parse start date value"}) return # End date - end_date = data.get('end', None) + end_date = data.get("end", None) if not end_date: self.set_status(400) - self.write({'message': 'End date missing'}) + self.write({"message": "End date missing"}) return try: end_date = parser.parse(end_date) @@ -58,21 +59,25 @@ async def post(self): raise Exception except BaseException: self.set_status(400) - self.write({'message': 'Unable to parse end date value'}) + self.write({"message": "Unable to parse end date value"}) return if start_date > end_date or end_date < start_date: self.set_status(400) - self.write({'message': 'Start date must be before or the same as the end date'}) + self.write( + {"message": "Start date must be before or the same as the end date"} + ) return with self.make_session() as session: now_time = datetime.utcnow() - show = Show(name=show_name, - start_date=start_date, - end_date=end_date, - created_at=now_time, - edited_at=now_time) + show = Show( + name=show_name, + start_date=start_date, + end_date=end_date, + created_at=now_time, + edited_at=now_time, + ) session.add(show) session.flush() @@ -82,11 +87,13 @@ async def post(self): session.flush() # Auto insert the first script revision - script_revision = ScriptRevision(script_id=script.id, - revision=1, - created_at=now_time, - edited_at=now_time, - description='Initial script revision') + script_revision = ScriptRevision( + script_id=script.id, + revision=1, + created_at=now_time, + edited_at=now_time, + description="Initial script revision", + ) session.add(script_revision) session.flush() @@ -95,18 +102,18 @@ async def post(self): session.commit() - should_load = bool(self.get_query_argument('load', default='False')) + should_load = bool(self.get_query_argument("load", default="False")) if should_load: - await self.application.digi_settings.set('current_show', show.id) + await self.application.digi_settings.set("current_show", show.id) self.set_status(200) - self.write({'message': 'Successfully created show'}) + self.write({"message": "Successfully created show"}) @requires_show def get(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] show_schema = ShowSchema() show = None @@ -121,12 +128,12 @@ def get(self): self.write(show) else: self.set_status(404) - self.write({'message': '404 show not found'}) + self.write({"message": "404 show not found"}) @requires_show async def patch(self): current_show = self.get_current_show() - show_id = current_show['id'] + show_id = current_show["id"] with self.make_session() as session: show: Show = session.query(Show).get(show_id) @@ -135,18 +142,18 @@ async def patch(self): data = escape.json_decode(self.request.body) # Name - show_name = data.get('name', None) + show_name = data.get("name", None) if not show_name: self.set_status(400) - self.write({'message': 'Show name missing'}) + self.write({"message": "Show name missing"}) return show.name = show_name # Start date - start_date = data.get('start_date', None) + start_date = data.get("start_date", None) if not start_date: self.set_status(400) - self.write({'message': 'Start date missing'}) + self.write({"message": "Start date missing"}) return try: start_date = parser.parse(start_date) @@ -154,14 +161,14 @@ async def patch(self): raise Exception except BaseException: self.set_status(400) - self.write({'message': 'Unable to parse start date value'}) + self.write({"message": "Unable to parse start date value"}) return # End date - end_date = data.get('end_date', None) + end_date = data.get("end_date", None) if not end_date: self.set_status(400) - self.write({'message': 'End date missing'}) + self.write({"message": "End date missing"}) return try: end_date = parser.parse(end_date) @@ -169,34 +176,36 @@ async def patch(self): raise Exception except BaseException: self.set_status(400) - self.write({'message': 'Unable to parse end date value'}) + self.write({"message": "Unable to parse end date value"}) return if start_date > end_date or end_date < start_date: self.set_status(400) - self.write({ - 'message': 'Start date must be before or the same as the end date' - }) + self.write( + { + "message": "Start date must be before or the same as the end date" + } + ) return show.start_date = start_date show.end_date = end_date # First act - show.first_act_id = data.get('first_act_id', None) + show.first_act_id = data.get("first_act_id", None) show.edited_at = datetime.utcnow() session.commit() self.set_status(200) - await self.finish({'message': 'Successfully updated act'}) + await self.finish({"message": "Successfully updated act"}) - await self.application.ws_send_to_all('NOOP', 'GET_SHOW_DETAILS', {}) + await self.application.ws_send_to_all("NOOP", "GET_SHOW_DETAILS", {}) else: self.set_status(404) - await self.finish({'message': '404 show not found'}) + await self.finish({"message": "404 show not found"}) -@ApiRoute('shows', ApiVersion.V1) +@ApiRoute("shows", ApiVersion.V1) class ShowsController(BaseAPIController): def get(self): @@ -207,4 +216,4 @@ def get(self): shows = [show_schema.dump(s) for s in shows] self.set_status(200) - self.write({'shows': shows}) + self.write({"shows": shows}) diff --git a/server/controllers/api/websocket.py b/server/controllers/api/websocket.py index f229d581..0bb54cda 100644 --- a/server/controllers/api/websocket.py +++ b/server/controllers/api/websocket.py @@ -4,7 +4,7 @@ from utils.web.route import ApiRoute, ApiVersion -@ApiRoute('ws/sessions', ApiVersion.V1) +@ApiRoute("ws/sessions", ApiVersion.V1) class WebsocketSessionsController(BaseAPIController): def get(self): @@ -14,4 +14,4 @@ def get(self): sessions = [session_scheme.dump(s) for s in sessions] self.set_status(200) - self.write({'sessions': sessions}) + self.write({"sessions": sessions}) diff --git a/server/controllers/controllers.py b/server/controllers/controllers.py index d36a11ee..9db7f268 100644 --- a/server/controllers/controllers.py +++ b/server/controllers/controllers.py @@ -7,18 +7,17 @@ from digi_server.logger import get_logger from utils.pkg_utils import find_end_modules -from utils.web.base_controller import BaseController, BaseAPIController +from utils.web.base_controller import BaseAPIController, BaseController from utils.web.route import ApiRoute, ApiVersion, Route - IMPORTED_CONTROLLERS = {} def import_all_controllers(): - controllers = find_end_modules('.', prefix='controllers') + controllers = find_end_modules(".", prefix="controllers") for controller in controllers: if controller != __name__: - get_logger().debug(f'Importing controller module {controller}') + get_logger().debug(f"Importing controller module {controller}") mod = importlib.import_module(controller) IMPORTED_CONTROLLERS[controller] = mod @@ -26,34 +25,33 @@ def import_all_controllers(): class RootController(BaseController): def get(self, path): file_path = os.path.join( - os.path.abspath( - os.path.dirname(__file__)), - "..", - "static") + os.path.abspath(os.path.dirname(__file__)), "..", "static" + ) full_path = os.path.join(file_path, "index.html") if not os.path.isfile(full_path): raise HTTPError(404) - with open(full_path, 'r') as file: + with open(full_path, "r") as file: self.write(file.read()) class StaticController(BaseController): def get(self): - self.set_header('Content-Type', '') + self.set_header("Content-Type", "") full_path = os.path.join( - os.path.abspath( - os.path.dirname(__file__)), "..", "static", url_unescape( - self.request.uri).strip( - os.path.sep)) + os.path.abspath(os.path.dirname(__file__)), + "..", + "static", + url_unescape(self.request.uri).strip(os.path.sep), + ) if not os.path.isfile(full_path): raise HTTPError(404) try: - with open(full_path, 'r') as file: + with open(full_path, "r") as file: self.write(file.read()) except UnicodeDecodeError: - with open(full_path, 'rb') as file: + with open(full_path, "rb") as file: self.write(file.read()) except Exception as exc: raise HTTPError(500) from exc @@ -62,40 +60,37 @@ def get(self): class ApiFallback(BaseAPIController): def get(self): self.set_status(404) - self.write({'message': '404 not found'}) + self.write({"message": "404 not found"}) def post(self): self.set_status(404) - self.write({'message': '404 not found'}) + self.write({"message": "404 not found"}) def patch(self): self.set_status(404) - self.write({'message': '404 not found'}) + self.write({"message": "404 not found"}) def delete(self): self.set_status(404) - self.write({'message': '404 not found'}) + self.write({"message": "404 not found"}) -@Route('/debug') +@Route("/debug") class DebugController(BaseController): def get(self): self.set_status(200) - self.set_header('Content-Type', 'application/json') - self.write({ - 'status': 'OK', - 'imported_controllers': list(IMPORTED_CONTROLLERS) - }) + self.set_header("Content-Type", "application/json") + self.write({"status": "OK", "imported_controllers": list(IMPORTED_CONTROLLERS)}) -@ApiRoute('debug', ApiVersion.V1) +@ApiRoute("debug", ApiVersion.V1) class ApiDebugController(BaseAPIController): def get(self): self.set_status(200) - self.set_header('Content-Type', 'application/json') - self.write({'status': 'OK', 'api_version': 1}) + self.set_header("Content-Type", "application/json") + self.write({"status": "OK", "api_version": 1}) -@Route('/debug/metrics', ignore_logging=True) +@Route("/debug/metrics", ignore_logging=True) class DebugMetricsController(MetricsHandler, BaseController): pass diff --git a/server/controllers/ws_controller.py b/server/controllers/ws_controller.py index 0197d204..c679b60b 100644 --- a/server/controllers/ws_controller.py +++ b/server/controllers/ws_controller.py @@ -1,11 +1,10 @@ import json -from typing import Optional, Awaitable, Union, TYPE_CHECKING, Dict, Any - +from typing import TYPE_CHECKING, Any, Awaitable, Dict, Optional, Union from uuid import uuid4 from tornado import gen from tornado.concurrent import Future -from tornado.websocket import WebSocketHandler, WebSocketClosedError +from tornado.websocket import WebSocketClosedError, WebSocketHandler from tornado_sqlalchemy import SessionMixin from digi_server.logger import get_logger @@ -17,60 +16,60 @@ from digi_server.app_server import DigiScriptServer -@ApiRoute('ws', ApiVersion.V1) +@ApiRoute("ws", ApiVersion.V1) class WebSocketController(SessionMixin, WebSocketHandler): def __init__(self, application, request, **kwargs): super().__init__(application, request, **kwargs) - self.application: DigiScriptServer = self.application # pylint: disable=used-before-assignment + self.application: DigiScriptServer = ( + self.application + ) # pylint: disable=used-before-assignment def update_session(self, is_editor=False, user_id=None): with self.make_session() as session: - entry = session.get(Session, self.__getattribute__('internal_id')) + entry = session.get(Session, self.__getattribute__("internal_id")) if entry: entry.last_ping = self.ws_connection.last_ping entry.last_pong = self.ws_connection.last_pong session.commit() else: - session.add(Session(internal_id=self.__getattribute__('internal_id'), - remote_ip=self.request.remote_ip, - last_ping=self.ws_connection.last_ping, - last_pong=self.ws_connection.last_pong, - user_id=user_id, - is_editor=is_editor)) + session.add( + Session( + internal_id=self.__getattribute__("internal_id"), + remote_ip=self.request.remote_ip, + last_ping=self.ws_connection.last_ping, + last_pong=self.ws_connection.last_pong, + user_id=user_id, + is_editor=is_editor, + ) + ) session.commit() def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: - raise RuntimeError( - f'Data streaming not supported for {self.__class__}') + raise RuntimeError(f"Data streaming not supported for {self.__class__}") def check_origin(self, origin): - if self.settings.get('debug', False): + if self.settings.get("debug", False): return True return super().check_origin(origin) @gen.coroutine def open(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]: - self.__setattr__('internal_id', str(uuid4())) + self.__setattr__("internal_id", str(uuid4())) self.application.clients.append(self) - user_id = self.get_secure_cookie('digiscript_user_id') + user_id = self.get_secure_cookie("digiscript_user_id") if user_id is not None: user_id = int(user_id) self.update_session(user_id=user_id) - get_logger().info(f'WebSocket opened from: {self.request.remote_ip}') + get_logger().info(f"WebSocket opened from: {self.request.remote_ip}") - yield self.write_message({ - 'OP': 'SET_UUID', - 'DATA': self.__getattribute__('internal_id') - }) + yield self.write_message( + {"OP": "SET_UUID", "DATA": self.__getattribute__("internal_id")} + ) - yield self.write_message({ - 'OP': 'NOOP', - 'DATA': {}, - 'ACTION': 'GET_SETTINGS' - }) + yield self.write_message({"OP": "NOOP", "DATA": {}, "ACTION": "GET_SETTINGS"}) def on_close(self) -> None: if self in self.application.clients: @@ -80,7 +79,7 @@ def on_close(self) -> None: elect_live_leader = False with self.make_session() as session: - entry = session.get(Session, self.__getattribute__('internal_id')) + entry = session.get(Session, self.__getattribute__("internal_id")) if entry: if entry.is_editor: notify_editor_change = True @@ -92,179 +91,217 @@ def on_close(self) -> None: if notify_editor_change: for client in self.application.clients: - client.write_message({ - 'OP': 'NOOP', - 'ACTION': 'GET_SCRIPT_CONFIG_STATUS', - 'DATA': {} - }) + client.write_message( + {"OP": "NOOP", "ACTION": "GET_SCRIPT_CONFIG_STATUS", "DATA": {}} + ) if elect_live_leader: - current_show = self.application.digi_settings.settings.get('current_show').get_value() + current_show = self.application.digi_settings.settings.get( + "current_show" + ).get_value() if current_show: with self.make_session() as session: show = session.query(Show).get(current_show) if show.current_session_id: - live_session: ShowSession = session.query(ShowSession).get(show.current_session_id) - live_session.last_client_internal_id = self.__getattribute__('internal_id') + live_session: ShowSession = session.query(ShowSession).get( + show.current_session_id + ) + live_session.last_client_internal_id = self.__getattribute__( + "internal_id" + ) session.flush() - next_session: Session = session.query(Session).filter( - Session.user_id == live_session.user_id).first() + next_session: Session = ( + session.query(Session) + .filter(Session.user_id == live_session.user_id) + .first() + ) if next_session: next_ws = self.application.get_ws(next_session.internal_id) if not next_ws: - get_logger().error('Unable to elect new leader of live session') + get_logger().error( + "Unable to elect new leader of live session" + ) else: - live_session.client_internal_id = next_session.internal_id + live_session.client_internal_id = ( + next_session.internal_id + ) live_session.last_client_internal_id = None - next_ws.write_message({ - 'OP': 'NOOP', - 'ACTION': 'ELECTED_LEADER', - 'DATA': { - 'latest_line_ref': live_session.latest_line_ref + next_ws.write_message( + { + "OP": "NOOP", + "ACTION": "ELECTED_LEADER", + "DATA": { + "latest_line_ref": live_session.latest_line_ref + }, } - }) + ) else: for client in self.application.clients: - client.write_message({ - 'OP': 'NOOP', - 'ACTION': 'NO_LEADER', - 'DATA': {} - }) + client.write_message( + {"OP": "NOOP", "ACTION": "NO_LEADER", "DATA": {}} + ) session.commit() for client in self.application.clients: - client.write_message({ - 'OP': 'NOOP', - 'ACTION': 'GET_SHOW_SESSION_DATA', - 'DATA': {} - }) + client.write_message( + { + "OP": "NOOP", + "ACTION": "GET_SHOW_SESSION_DATA", + "DATA": {}, + } + ) - get_logger().info(f'WebSocket closed from: {self.request.remote_ip}') + get_logger().info(f"WebSocket closed from: {self.request.remote_ip}") - async def on_message(self, message: Union[str, bytes]): # pylint: disable=invalid-overridden-method + async def on_message( + self, message: Union[str, bytes] + ): # pylint: disable=invalid-overridden-method get_logger().debug( - f'WebSocket received data from {self.request.remote_ip}: {message}') + f"WebSocket received data from {self.request.remote_ip}: {message}" + ) message = json.loads(message) - ws_op = message['OP'] + ws_op = message["OP"] with self.make_session() as session: - entry: Session = session.get(Session, self.__getattribute__('internal_id')) - current_show = await self.application.digi_settings.get('current_show') + entry: Session = session.get(Session, self.__getattribute__("internal_id")) + current_show = await self.application.digi_settings.get("current_show") if current_show: show = session.query(Show).get(current_show) else: show = None show_session: Optional[ShowSession] = None - user_id = self.get_secure_cookie('digiscript_user_id') + user_id = self.get_secure_cookie("digiscript_user_id") if user_id is not None: user_id = int(user_id) - if ws_op == 'NEW_CLIENT': + if ws_op == "NEW_CLIENT": if user_id and show and show.current_session_id: - show_session = session.query(ShowSession).get(show.current_session_id) + show_session = session.query(ShowSession).get( + show.current_session_id + ) if show_session and not show_session.client_internal_id: if show_session.user_id == user_id: - show_session.client_internal_id = self.__getattribute__('internal_id') + show_session.client_internal_id = self.__getattribute__( + "internal_id" + ) session.commit() - await self.write_message({ - 'OP': 'NOOP', - 'ACTION': 'ELECTED_LEADER', - 'DATA': { - 'latest_line_ref': show_session.latest_line_ref + await self.write_message( + { + "OP": "NOOP", + "ACTION": "ELECTED_LEADER", + "DATA": { + "latest_line_ref": show_session.latest_line_ref + }, } - }) + ) await self.application.ws_send_to_all( - 'NOOP', - 'GET_SHOW_SESSION_DATA', - {} + "NOOP", "GET_SHOW_SESSION_DATA", {} ) - elif ws_op == 'REFRESH_CLIENT': - new_uuid = message['DATA'] + elif ws_op == "REFRESH_CLIENT": + new_uuid = message["DATA"] is_editor = False update_session_client = False if entry: is_editor = entry.is_editor if show and show.current_session_id: - show_session = session.query(ShowSession).get(show.current_session_id) - if show_session and show_session.last_client_internal_id == new_uuid: + show_session = session.query(ShowSession).get( + show.current_session_id + ) + if ( + show_session + and show_session.last_client_internal_id == new_uuid + ): update_session_client = True session.delete(entry) session.commit() - self.__setattr__('internal_id', new_uuid) + self.__setattr__("internal_id", new_uuid) self.update_session(is_editor=is_editor, user_id=user_id) if update_session_client: show_session.client_internal_id = new_uuid show_session.last_client_internal_id = None session.commit() await self.application.ws_send_to_all( - 'NOOP', - 'GET_SHOW_SESSION_DATA', - {} + "NOOP", "GET_SHOW_SESSION_DATA", {} ) - elif ws_op == 'REQUEST_SCRIPT_EDIT': + elif ws_op == "REQUEST_SCRIPT_EDIT": editors = session.query(Session).filter(Session.is_editor).all() if len(editors) == 0: entry.is_editor = True session.commit() - await self.application.ws_send_to_all('NOOP', 'GET_SCRIPT_CONFIG_STATUS', {}) + await self.application.ws_send_to_all( + "NOOP", "GET_SCRIPT_CONFIG_STATUS", {} + ) else: - await self.write_message({ - 'OP': 'NOOP', - 'ACTION': 'REQUEST_EDIT_FAILURE', - 'DATA': {} - }) - elif ws_op == 'STOP_SCRIPT_EDIT': + await self.write_message( + {"OP": "NOOP", "ACTION": "REQUEST_EDIT_FAILURE", "DATA": {}} + ) + elif ws_op == "STOP_SCRIPT_EDIT": if entry.is_editor: entry.is_editor = False session.commit() - await self.application.ws_send_to_all('NOOP', 'GET_SCRIPT_CONFIG_STATUS', {}) - elif ws_op == 'SCRIPT_SCROLL': + await self.application.ws_send_to_all( + "NOOP", "GET_SCRIPT_CONFIG_STATUS", {} + ) + elif ws_op == "SCRIPT_SCROLL": if show and show.current_session_id: - show_session = session.query(ShowSession).get(show.current_session_id) + show_session = session.query(ShowSession).get( + show.current_session_id + ) if show_session: - if show_session.client_internal_id == self.__getattribute__('internal_id'): - show_session.latest_line_ref = message['DATA']['current_line'] + if show_session.client_internal_id == self.__getattribute__( + "internal_id" + ): + show_session.latest_line_ref = message["DATA"][ + "current_line" + ] session.commit() await self.application.ws_send_to_all( - 'NOOP', - 'SCRIPT_SCROLL', - message['DATA'] + "NOOP", "SCRIPT_SCROLL", message["DATA"] ) - elif ws_op == 'RELOAD_CLIENTS': + elif ws_op == "RELOAD_CLIENTS": if show and show.current_session_id: - show_session = session.query(ShowSession).get(show.current_session_id) - if show_session and show_session.client_internal_id == self.__getattribute__('internal_id'): + show_session = session.query(ShowSession).get( + show.current_session_id + ) + if ( + show_session + and show_session.client_internal_id + == self.__getattribute__("internal_id") + ): await self.application.ws_send_to_all( - 'RELOAD_CLIENT', - 'NOOP', - {} + "RELOAD_CLIENT", "NOOP", {} ) else: - get_logger().warning(f'Unknown OP {ws_op} received from ' - f'WebSocket connection {self.request.remote_ip}') + get_logger().warning( + f"Unknown OP {ws_op} received from " + f"WebSocket connection {self.request.remote_ip}" + ) def on_pong(self, data: bytes) -> None: self.update_session() get_logger().trace( - f'Ping response from {self.request.remote_ip} : {data.hex()}') + f"Ping response from {self.request.remote_ip} : {data.hex()}" + ) def on_ping(self, data: bytes) -> None: self.update_session() - get_logger().trace( - f'Ping from {self.request.remote_ip} : {data.hex()}') + get_logger().trace(f"Ping from {self.request.remote_ip} : {data.hex()}") @gen.coroutine - def write_message(self, message: Union[bytes, str, Dict[str, Any]], - binary: bool = False) -> Future[None]: + def write_message( + self, message: Union[bytes, str, Dict[str, Any]], binary: bool = False + ) -> Future[None]: try: return super().write_message(message, binary) except WebSocketClosedError: - get_logger().error(f'Trying to send message to closed websocket ' - f'{self.__getattribute__("internal_id")} at IP address ' - f'{self.request.remote_ip}, closing.') + get_logger().error( + f"Trying to send message to closed websocket " + f'{self.__getattribute__("internal_id")} at IP address ' + f"{self.request.remote_ip}, closing." + ) self.on_close() return None diff --git a/server/digi_server/app_server.py b/server/digi_server/app_server.py index 285f8655..ef811a69 100644 --- a/server/digi_server/app_server.py +++ b/server/digi_server/app_server.py @@ -13,13 +13,13 @@ from controllers import controllers from controllers.ws_controller import WebSocketController -from digi_server.logger import get_logger, configure_file_logging, configure_db_logging +from digi_server.logger import configure_db_logging, configure_file_logging, get_logger from digi_server.settings import Settings from models import models from models.cue import CueType from models.script import Script -from models.show import Show from models.session import Session, ShowSession +from models.show import Show from models.user import User from rbac.rbac import RBACController from utils.database import DigiSQLAlchemy @@ -30,7 +30,13 @@ class DigiScriptServer(PrometheusMixIn, Application): - def __init__(self, debug=False, settings_path=None, skip_migrations=False, skip_migrations_check=False): + def __init__( + self, + debug=False, + settings_path=None, + skip_migrations=False, + skip_migrations_check=False, + ): self.env_parser: EnvParser = EnvParser.instance() # pylint: disable=no-member self.digi_settings: Settings = Settings(self, settings_path) @@ -49,15 +55,15 @@ def __init__(self, debug=False, settings_path=None, skip_migrations=False, skip_ if not skip_migrations: self._run_migrations() else: - get_logger().warning('Skipping performing database migrations') + get_logger().warning("Skipping performing database migrations") # And then check the database is up-to-date if not skip_migrations_check: self._check_migrations() else: - get_logger().warning('Skipping database migrations check') + get_logger().warning("Skipping database migrations check") # Finally, configure the database - db_path = self.digi_settings.settings.get('db_path').get_value() - get_logger().info(f'Using {db_path} as DB path') + db_path = self.digi_settings.settings.get("db_path").get_value() + get_logger().info(f"Using {db_path} as DB path") self._db.configure(url=db_path) self.rbac = RBACController(self) @@ -67,7 +73,7 @@ def __init__(self, debug=False, settings_path=None, skip_migrations=False, skip_ # Clear out all sessions since we are starting the app up with self._db.sessionmaker() as session: - get_logger().debug('Emptying out sessions table!') + get_logger().debug("Emptying out sessions table!") session.query(Session).delete() session.commit() @@ -75,50 +81,61 @@ def __init__(self, debug=False, settings_path=None, skip_migrations=False, skip_ with self._db.sessionmaker() as session: any_admin = session.query(User).filter(User.is_admin).first() has_admin = any_admin is not None - self.digi_settings.settings['has_admin_user'].set_value(has_admin, False) + self.digi_settings.settings["has_admin_user"].set_value(has_admin, False) self.digi_settings._save() # Check the show we are expecting to be loaded exists with self._db.sessionmaker() as session: - current_show = self.digi_settings.settings.get('current_show').get_value() + current_show = self.digi_settings.settings.get("current_show").get_value() if current_show: show = session.query(Show).get(current_show) if not show: - get_logger().warning('Current show from settings not found. Resetting.') - self.digi_settings.settings['current_show'].set_to_default() + get_logger().warning( + "Current show from settings not found. Resetting." + ) + self.digi_settings.settings["current_show"].set_to_default() self.digi_settings._save() # If there is a live session in progress, clean up the current client ID with self._db.sessionmaker() as session: - current_show = self.digi_settings.settings.get('current_show').get_value() + current_show = self.digi_settings.settings.get("current_show").get_value() if current_show: show = session.query(Show).get(current_show) if show and show.current_session_id: - show_session: ShowSession = session.query(ShowSession).get(show.current_session_id) + show_session: ShowSession = session.query(ShowSession).get( + show.current_session_id + ) if show_session: - show_session.last_client_internal_id = show_session.client_internal_id + show_session.last_client_internal_id = ( + show_session.client_internal_id + ) show_session.client_internal_id = None else: - get_logger().warning('Current show session not found. Resetting.') + get_logger().warning( + "Current show session not found. Resetting." + ) show.current_session_id = None session.commit() - static_files_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), - '..', 'static', 'assets') - get_logger().info(f'Using {static_files_path} as static files path') + static_files_path = os.path.join( + os.path.abspath(os.path.dirname(__file__)), "..", "static", "assets" + ) + get_logger().info(f"Using {static_files_path} as static files path") handlers = Route.routes() - handlers.append(('/favicon.ico', controllers.StaticController)) - handlers.append((r'/assets/(.*)', StaticFileHandler, {'path': static_files_path})) - handlers.append((r'/api/.*', controllers.ApiFallback)) - handlers.append((r'/(.*)', controllers.RootController)) + handlers.append(("/favicon.ico", controllers.StaticController)) + handlers.append( + (r"/assets/(.*)", StaticFileHandler, {"path": static_files_path}) + ) + handlers.append((r"/api/.*", controllers.ApiFallback)) + handlers.append((r"/(.*)", controllers.RootController)) super().__init__( handlers=handlers, debug=debug, db=self._db, websocket_ping_interval=5, - cookie_secret='DigiScriptSuperSecretValue123!', - login_url='/login', + cookie_secret="DigiScriptSuperSecretValue123!", + login_url="/login", ) def log_request(self, handler): @@ -129,67 +146,78 @@ def log_request(self, handler): @property def _alembic_config(self): - alembic_cfg_path = os.path.join(os.path.dirname(__file__), '..', 'alembic.ini') + alembic_cfg_path = os.path.join(os.path.dirname(__file__), "..", "alembic.ini") alembic_cfg = Config(alembic_cfg_path) # Override config options with specific ones based on this running instance - alembic_cfg.set_main_option('digiscript.config', self.digi_settings.settings_path) - alembic_cfg.set_main_option('configure_logging', 'False') + alembic_cfg.set_main_option( + "digiscript.config", self.digi_settings.settings_path + ) + alembic_cfg.set_main_option("configure_logging", "False") return alembic_cfg def _run_migrations(self): try: self._check_migrations() except DatabaseUpgradeRequired: - get_logger().info('Running database migrations via Alembic') + get_logger().info("Running database migrations via Alembic") # Create a copy of the database file as a backup before performing migrations - db_path: str = self.digi_settings.settings.get('db_path').get_value() - if db_path.startswith('sqlite:///'): - db_path = db_path.replace('sqlite:///', '') + db_path: str = self.digi_settings.settings.get("db_path").get_value() + if db_path.startswith("sqlite:///"): + db_path = db_path.replace("sqlite:///", "") if os.path.exists(db_path) and os.path.isfile(db_path): - get_logger().info('Creating copy of database file as backup') - new_file_name = f'{db_path}.{int(time.time())}' + get_logger().info("Creating copy of database file as backup") + new_file_name = f"{db_path}.{int(time.time())}" shutil.copyfile(db_path, new_file_name) - get_logger().info(f'Created copy of database file as backup, saved to {new_file_name}') + get_logger().info( + f"Created copy of database file as backup, saved to {new_file_name}" + ) else: - get_logger().warning('Database connection does not appear to be a file, cannot create backup!') + get_logger().warning( + "Database connection does not appear to be a file, cannot create backup!" + ) # Run the upgrade on the database - command.upgrade(self._alembic_config, 'head') + command.upgrade(self._alembic_config, "head") else: - get_logger().info('No database migrations to perform') + get_logger().info("No database migrations to perform") def _check_migrations(self): - get_logger().info('Checking database migrations via Alembic') - engine = sqlalchemy.create_engine(self.digi_settings.settings.get('db_path').get_value()) + get_logger().info("Checking database migrations via Alembic") + engine = sqlalchemy.create_engine( + self.digi_settings.settings.get("db_path").get_value() + ) script_ = script.ScriptDirectory.from_config(self._alembic_config) with engine.begin() as conn: context = migration.MigrationContext.configure(conn) if context.get_current_revision() != script_.get_current_head(): - raise DatabaseUpgradeRequired('Migrations required on the database') + raise DatabaseUpgradeRequired("Migrations required on the database") async def configure(self): await self._configure_logging() async def _configure_logging(self): - get_logger().info('Reconfiguring logging!') + get_logger().info("Reconfiguring logging!") # Application logging - log_path = await self.digi_settings.get('log_path') - file_size = await self.digi_settings.get('max_log_mb') - backups = await self.digi_settings.get('log_backups') + log_path = await self.digi_settings.get("log_path") + file_size = await self.digi_settings.get("max_log_mb") + backups = await self.digi_settings.get("log_backups") if log_path: - self.app_log_handler = configure_file_logging(log_path, file_size, backups, - self.app_log_handler) + self.app_log_handler = configure_file_logging( + log_path, file_size, backups, self.app_log_handler + ) # Database logging - use_db_logging = await self.digi_settings.get('db_log_enabled') + use_db_logging = await self.digi_settings.get("db_log_enabled") if use_db_logging: - db_log_path = await self.digi_settings.get('db_log_path') - db_file_size = await self.digi_settings.get('db_max_log_mb') - db_backups = await self.digi_settings.get('db_log_backups') - self.db_file_handler = configure_db_logging(log_path=db_log_path, - max_size_mb=db_file_size, - log_backups=db_backups, - handler=self.db_file_handler) + db_log_path = await self.digi_settings.get("db_log_path") + db_file_size = await self.digi_settings.get("db_max_log_mb") + db_backups = await self.digi_settings.get("db_log_backups") + self.db_file_handler = configure_db_logging( + log_path=db_log_path, + max_size_mb=db_file_size, + log_backups=db_backups, + handler=self.db_file_handler, + ) def _configure_rbac(self): self.rbac.add_mapping(User, Show, [Show.id, Show.name]) @@ -198,19 +226,25 @@ def _configure_rbac(self): def regen_logging(self): if not IOLoop.current(): - get_logger().error('Unable to regenerate logging as there is no current IOLoop') + get_logger().error( + "Unable to regenerate logging as there is no current IOLoop" + ) else: IOLoop.current().add_callback(self._configure_logging) def validate_has_admin(self): if not IOLoop.current(): - get_logger().error('Unable to validate admin user as there is no current IOLoop') + get_logger().error( + "Unable to validate admin user as there is no current IOLoop" + ) else: IOLoop.current().add_callback(self._validate_has_admin) def show_changed(self): if not IOLoop.current(): - get_logger().error('Unable to initiate show change as there is no current IOLoop') + get_logger().error( + "Unable to initiate show change as there is no current IOLoop" + ) else: IOLoop.current().add_callback(self._show_changed) @@ -218,10 +252,10 @@ async def _validate_has_admin(self): with self.get_db().sessionmaker() as session: any_admin = session.query(User).filter(User.is_admin).first() has_admin = any_admin is not None - await self.digi_settings.set('has_admin_user', has_admin) + await self.digi_settings.set("has_admin_user", has_admin) async def _show_changed(self): - await self.ws_send_to_all('NOOP', 'SHOW_CHANGED', {}) + await self.ws_send_to_all("NOOP", "SHOW_CHANGED", {}) def get_db(self) -> DigiSQLAlchemy: return self._db @@ -229,21 +263,19 @@ def get_db(self) -> DigiSQLAlchemy: def get_all_ws(self, user_id) -> List[WebSocketController]: sockets = [] for client in self.clients: - c_user_id = client.get_secure_cookie('digiscript_user_id') + c_user_id = client.get_secure_cookie("digiscript_user_id") if c_user_id is not None and int(c_user_id) == user_id: sockets.append(client) return sockets def get_ws(self, internal_uuid: str) -> Optional[WebSocketController]: for client in self.clients: - if client.__getattribute__('internal_id') == internal_uuid: + if client.__getattribute__("internal_id") == internal_uuid: return client return None async def ws_send_to_all(self, ws_op: str, ws_action: str, ws_data: dict): for client in self.clients: - await client.write_message({ - 'OP': ws_op, - 'DATA': ws_data, - 'ACTION': ws_action - }) + await client.write_message( + {"OP": ws_op, "DATA": ws_data, "ACTION": ws_action} + ) diff --git a/server/digi_server/logger.py b/server/digi_server/logger.py index aa20ad8a..cb52c6a4 100644 --- a/server/digi_server/logger.py +++ b/server/digi_server/logger.py @@ -1,10 +1,9 @@ import logging from logging.handlers import RotatingFileHandler - from tornado.log import LogFormatter -logger = logging.getLogger('DigiScript') +logger = logging.getLogger("DigiScript") def get_logger(): @@ -12,27 +11,28 @@ def get_logger(): def configure_file_logging(log_path, max_size_mb=100, log_backups=5, handler=None): - size_bytes = max_size_mb*1024*1024 + size_bytes = max_size_mb * 1024 * 1024 app_logger = get_logger() if handler: app_logger.removeHandler(handler) - file_handler = RotatingFileHandler(log_path, - maxBytes=size_bytes, - backupCount=log_backups) + file_handler = RotatingFileHandler( + log_path, maxBytes=size_bytes, backupCount=log_backups + ) file_handler.setFormatter(LogFormatter(color=False)) app_logger.addHandler(file_handler) - logging.getLogger('tornado.access').addHandler(file_handler) - logging.getLogger('tornado.application').addHandler(file_handler) - logging.getLogger('tornado.general').addHandler(file_handler) + logging.getLogger("tornado.access").addHandler(file_handler) + logging.getLogger("tornado.application").addHandler(file_handler) + logging.getLogger("tornado.general").addHandler(file_handler) return file_handler -def configure_db_logging(log_level=logging.DEBUG, log_path=None, max_size_mb=100, log_backups=5, - handler=None): - size_bytes = max_size_mb*1024*1024 - db_logger = logging.getLogger('sqlalchemy.engine') +def configure_db_logging( + log_level=logging.DEBUG, log_path=None, max_size_mb=100, log_backups=5, handler=None +): + size_bytes = max_size_mb * 1024 * 1024 + db_logger = logging.getLogger("sqlalchemy.engine") if handler: db_logger.removeHandler(handler) @@ -40,9 +40,9 @@ def configure_db_logging(log_level=logging.DEBUG, log_path=None, max_size_mb=100 db_logger.setLevel(log_level) file_handler = None if log_path: - file_handler = RotatingFileHandler(log_path, - maxBytes=size_bytes, - backupCount=log_backups) + file_handler = RotatingFileHandler( + log_path, maxBytes=size_bytes, backupCount=log_backups + ) file_handler.setFormatter(LogFormatter(color=False)) db_logger.addHandler(file_handler) db_logger.propagate = False @@ -54,15 +54,17 @@ def add_logging_level(level_name, level_num, method_name=None): method_name = level_name.lower() if hasattr(logging, level_name): - raise AttributeError(f'{level_name} already defined in logging module') + raise AttributeError(f"{level_name} already defined in logging module") if hasattr(logging, method_name): - raise AttributeError(f'{method_name} already defined in logging module') + raise AttributeError(f"{method_name} already defined in logging module") if hasattr(logging.getLoggerClass(), method_name): - raise AttributeError(f'{method_name} already defined in logger class') + raise AttributeError(f"{method_name} already defined in logger class") def log_for_level(self, message, *args, **kwargs): if self.isEnabledFor(level_num): - self._log(level_num, message, args, **kwargs) # pylint: disable=protected-access + self._log( + level_num, message, args, **kwargs + ) # pylint: disable=protected-access def log_to_root(message, *args, **kwargs): logging.log(level_num, message, *args, **kwargs) diff --git a/server/digi_server/settings.py b/server/digi_server/settings.py index 7a66e624..49270d18 100644 --- a/server/digi_server/settings.py +++ b/server/digi_server/settings.py @@ -14,11 +14,22 @@ class SettingsObject: # pylint: disable=too-many-instance-attributes - def __init__(self, key, val_type, default, can_edit=True, callback_fn=None, nullable=False, - display_name: str = "", help_text: str = ""): + def __init__( + self, + key, + val_type, + default, + can_edit=True, + callback_fn=None, + nullable=False, + display_name: str = "", + help_text: str = "", + ): if val_type not in [str, bool, int]: - raise RuntimeError(f'Invalid type {val_type} for {key}. Allowed options are: ' - f'[str, int, bool]') + raise RuntimeError( + f"Invalid type {val_type} for {key}. Allowed options are: " + f"[str, int, bool]" + ) self.key = key self.val_type = val_type @@ -38,10 +49,14 @@ def set_to_default(self): def set_value(self, value, spawn_callbacks=True): if not isinstance(value, self.val_type): if value is None and not self._nullable: - raise RuntimeError(f'Value for {self.key} cannot be None (is not nullable)') + raise RuntimeError( + f"Value for {self.key} cannot be None (is not nullable)" + ) if value is not None: - raise TypeError(f'Value for {self.key} of {value} is not of assigned ' - f'type {self.val_type}') + raise TypeError( + f"Value for {self.key} of {value} is not of assigned " + f"type {self.val_type}" + ) changed = False if value != self.value: @@ -61,12 +76,12 @@ def is_loaded(self): def as_json(self): return { - 'type': self.val_type.__name__, - 'value': self.value, - 'default': self.default, - 'can_edit': self.can_edit, - 'display_name': self.display_name, - 'help_text': self.help_text, + "type": self.val_type.__name__, + "value": self.value, + "default": self.default, + "can_edit": self.can_edit, + "display_name": self.display_name, + "help_text": self.help_text, } @@ -74,70 +89,165 @@ class Settings: def __init__(self, application: DigiScriptServer, settings_path=None): self._application = application self.lock = Lock() - self._base_path = os.path.join(os.path.dirname(__file__), '..', 'conf') + self._base_path = os.path.join(os.path.dirname(__file__), "..", "conf") if not os.path.exists(self._base_path): - get_logger().info(f'Creating base path {self._base_path}') + get_logger().info(f"Creating base path {self._base_path}") os.makedirs(self._base_path) if settings_path: self.settings_path = settings_path else: - self.settings_path = os.path.join(self._base_path, 'digiscript.json') - get_logger().info( - f'No settings path provided, using {self.settings_path}') + self.settings_path = os.path.join(self._base_path, "digiscript.json") + get_logger().info(f"No settings path provided, using {self.settings_path}") if not os.path.exists(os.path.dirname(self.settings_path)): - get_logger().info(f'Creating settings path {os.path.dirname(self.settings_path)}') + get_logger().info( + f"Creating settings path {os.path.dirname(self.settings_path)}" + ) os.makedirs(os.path.dirname(self.settings_path)) self.settings = {} db_default = f'sqlite:///{os.path.join(os.path.dirname(__file__), "../conf/digiscript.sqlite")}' - self.define('has_admin_user', bool, False, False, nullable=False, - callback_fn=self._application.validate_has_admin, display_name="Has Admin User") - self.define('db_path', str, db_default, False, nullable=False, display_name="Database Path") - self.define('current_show', int, None, False, nullable=True, - callback_fn=self._application.show_changed, display_name="Current Show ID") - self.define('debug_mode', bool, False, True, display_name="Enable Debug Mode") - self.define('log_path', str, os.path.join(self._base_path, 'digiscript.log'), True, - self._application.regen_logging, display_name="Application Log Path") - self.define('max_log_mb', int, 100, True, self._application.regen_logging, display_name="Max Log Size (MB)") - self.define('log_backups', int, 5, True, self._application.regen_logging, display_name="Log Backups") - self.define('db_log_enabled', bool, False, True, self._application.regen_logging, - display_name="Enable Database Log") - self.define('db_log_path', str, os.path.join(self._base_path, 'digiscript_db.log'), True, - self._application.regen_logging, display_name="Database Log Path") - self.define('db_max_log_mb', int, 100, True, self._application.regen_logging, - display_name="Max Database Log Size (MB)") - self.define('db_log_backups', int, 5, True, self._application.regen_logging, - display_name="Database Log Backups") - self.define('enable_lazy_loading', bool, True, True, - display_name="Enable Lazy Loading", - help_text="Whether the client side should load all script pages initially when connected " - "to a live show") - self.define('enable_live_batching', bool, True, True, - display_name="Enable Live Batching", - help_text="Whether the live show page should only display a subsection of the script pages " - "at a time") + self.define( + "has_admin_user", + bool, + False, + False, + nullable=False, + callback_fn=self._application.validate_has_admin, + display_name="Has Admin User", + ) + self.define( + "db_path", + str, + db_default, + False, + nullable=False, + display_name="Database Path", + ) + self.define( + "current_show", + int, + None, + False, + nullable=True, + callback_fn=self._application.show_changed, + display_name="Current Show ID", + ) + self.define("debug_mode", bool, False, True, display_name="Enable Debug Mode") + self.define( + "log_path", + str, + os.path.join(self._base_path, "digiscript.log"), + True, + self._application.regen_logging, + display_name="Application Log Path", + ) + self.define( + "max_log_mb", + int, + 100, + True, + self._application.regen_logging, + display_name="Max Log Size (MB)", + ) + self.define( + "log_backups", + int, + 5, + True, + self._application.regen_logging, + display_name="Log Backups", + ) + self.define( + "db_log_enabled", + bool, + False, + True, + self._application.regen_logging, + display_name="Enable Database Log", + ) + self.define( + "db_log_path", + str, + os.path.join(self._base_path, "digiscript_db.log"), + True, + self._application.regen_logging, + display_name="Database Log Path", + ) + self.define( + "db_max_log_mb", + int, + 100, + True, + self._application.regen_logging, + display_name="Max Database Log Size (MB)", + ) + self.define( + "db_log_backups", + int, + 5, + True, + self._application.regen_logging, + display_name="Database Log Backups", + ) + self.define( + "enable_lazy_loading", + bool, + True, + True, + display_name="Enable Lazy Loading", + help_text="Whether the client side should load all script pages initially when connected " + "to a live show", + ) + self.define( + "enable_live_batching", + bool, + True, + True, + display_name="Enable Live Batching", + help_text="Whether the live show page should only display a subsection of the script pages " + "at a time", + ) self._load(spawn_callbacks=False) - self._file_watcher = IOLoopFileWatcher(self.settings_path, self.auto_reload_changes, 100) + self._file_watcher = IOLoopFileWatcher( + self.settings_path, self.auto_reload_changes, 100 + ) self._file_watcher.add_error_callback(self.file_deleted) self._file_watcher.watch() - def define(self, key, val_type, default, can_edit, callback_fn=None, nullable=False, - display_name: str = "", help_text: str = ""): - self.settings[key] = SettingsObject(key, val_type, default, can_edit, callback_fn, - nullable, display_name, help_text) + def define( + self, + key, + val_type, + default, + can_edit, + callback_fn=None, + nullable=False, + display_name: str = "", + help_text: str = "", + ): + self.settings[key] = SettingsObject( + key, + val_type, + default, + can_edit, + callback_fn, + nullable, + display_name, + help_text, + ) def file_deleted(self): - get_logger().info('Settings file deleted; recreating from in memory settings') + get_logger().info("Settings file deleted; recreating from in memory settings") self._save() def auto_reload_changes(self): - get_logger().info('Settings file changed; auto reloading') + get_logger().info("Settings file changed; auto reloading") self._load(spawn_callbacks=True) settings_json = {} @@ -145,21 +255,25 @@ def auto_reload_changes(self): settings_json[key] = value.get_value() for client in self._application.clients: - client.write_message({ - 'OP': 'SETTINGS_CHANGED', - 'DATA': settings_json, - 'ACTION': 'WS_SETTINGS_CHANGED' - }) + client.write_message( + { + "OP": "SETTINGS_CHANGED", + "DATA": settings_json, + "ACTION": "WS_SETTINGS_CHANGED", + } + ) def _load(self, spawn_callbacks=False): if os.path.exists(self.settings_path): # Read in the settings - with open(self.settings_path, 'r', encoding='UTF-8') as file_pointer: + with open(self.settings_path, "r", encoding="UTF-8") as file_pointer: settings = json.load(file_pointer) for key, value in settings.items(): if key not in self.settings: - get_logger().warning(f'Setting {key} found in settings file is not ' - f'defined, ignoring!') + get_logger().warning( + f"Setting {key} found in settings file is not " + f"defined, ignoring!" + ) else: self.settings[key].set_value(value, spawn_callbacks) @@ -170,51 +284,54 @@ def _load(self, spawn_callbacks=False): value.set_to_default() needs_saving = True if needs_saving: - with open(self.settings_path, 'w', encoding='UTF-8') as file_pointer: + with open(self.settings_path, "w", encoding="UTF-8") as file_pointer: json.dump(self._json(), file_pointer, indent=4) - get_logger().info(f'Saved settings to {self.settings_path}') + get_logger().info(f"Saved settings to {self.settings_path}") - get_logger().info(f'Loaded settings from {self.settings_path}') + get_logger().info(f"Loaded settings from {self.settings_path}") else: # Set everything to its default value for key, value in self.settings.items(): value.set_to_default() - with open(self.settings_path, 'w', encoding='UTF-8') as file_pointer: + with open(self.settings_path, "w", encoding="UTF-8") as file_pointer: json.dump(self._json(), file_pointer, indent=4) - get_logger().info(f'Saved settings to {self.settings_path}') + get_logger().info(f"Saved settings to {self.settings_path}") def _save(self): self._file_watcher.pause() - with open(self.settings_path, 'w', encoding='UTF-8') as file_pointer: + with open(self.settings_path, "w", encoding="UTF-8") as file_pointer: json.dump(self._json(), file_pointer, indent=4) file_pointer.flush() os.fsync(file_pointer.fileno()) self._file_watcher.update_m_time() self._file_watcher.resume() - get_logger().info(f'Saved settings to {self.settings_path}') + get_logger().info(f"Saved settings to {self.settings_path}") async def get(self, key): async with self.lock: if key not in self.settings: - raise KeyError(f'{key} is not a valid setting') + raise KeyError(f"{key} is not a valid setting") return self.settings.get(key).get_value() async def set(self, key, item): changed = False async with self.lock: if key not in self.settings: - get_logger().warning(f'Setting {key} found in settings file is not ' - f'defined, ignoring!') + get_logger().warning( + f"Setting {key} found in settings file is not " + f"defined, ignoring!" + ) else: changed = self.settings[key].set_value(item) if changed: self._save() settings = await self.as_json() - await self._application.ws_send_to_all('SETTINGS_CHANGED', 'WS_SETTINGS_CHANGED', - settings) + await self._application.ws_send_to_all( + "SETTINGS_CHANGED", "WS_SETTINGS_CHANGED", settings + ) def _json(self): settings_json = {} diff --git a/server/main.py b/server/main.py index 5bb44f9c..09e50863 100755 --- a/server/main.py +++ b/server/main.py @@ -1,38 +1,37 @@ #!/usr/bin/env python3 import asyncio import logging + from tornado.options import define, options, parse_command_line -from digi_server.logger import get_logger, add_logging_level from digi_server.app_server import DigiScriptServer +from digi_server.logger import add_logging_level, get_logger -add_logging_level('TRACE', logging.DEBUG - 5) +add_logging_level("TRACE", logging.DEBUG - 5) get_logger().setLevel(logging.DEBUG) -define('debug', type=bool, default=True, help='auto reload') -define('port', type=int, default=8080, help='port to listen on') -define( - 'settings_path', - type=str, - default=None, - help='Path to settings JSON file') -define('skip_migrations', type=bool, default=False, help='skip database migrations') +define("debug", type=bool, default=True, help="auto reload") +define("port", type=int, default=8080, help="port to listen on") +define("settings_path", type=str, default=None, help="Path to settings JSON file") +define("skip_migrations", type=bool, default=False, help="skip database migrations") async def main(): parse_command_line() - app = DigiScriptServer(debug=options.debug, - settings_path=options.settings_path, - skip_migrations=options.skip_migrations) + app = DigiScriptServer( + debug=options.debug, + settings_path=options.settings_path, + skip_migrations=options.skip_migrations, + ) await app.configure() app.listen(options.port) - get_logger().info(f'Listening on port: {options.port}') + get_logger().info(f"Listening on port: {options.port}") if options.debug: - get_logger().warning('Running in debug mode') + get_logger().warning("Running in debug mode") await asyncio.Event().wait() -if __name__ == '__main__': +if __name__ == "__main__": asyncio.run(main()) diff --git a/server/models/cue.py b/server/models/cue.py index 10e02dc5..18e79f6d 100644 --- a/server/models/cue.py +++ b/server/models/cue.py @@ -1,55 +1,68 @@ -from sqlalchemy import Column, Integer, ForeignKey, String -from sqlalchemy.orm import relationship, backref +from sqlalchemy import Column, ForeignKey, Integer, String +from sqlalchemy.orm import backref, relationship from models.models import db -from models.script import ScriptRevision, ScriptLine +from models.script import ScriptLine, ScriptRevision from utils.database import DeleteMixin class CueType(db.Model): - __tablename__ = 'cuetypes' + __tablename__ = "cuetypes" id = Column(Integer, primary_key=True, autoincrement=True) - show_id = Column(Integer, ForeignKey('shows.id')) + show_id = Column(Integer, ForeignKey("shows.id")) prefix = Column(String(5)) description = Column(String(100)) colour = Column(String()) - cues = relationship('Cue', back_populates='cue_type', cascade='all, delete-orphan') + cues = relationship("Cue", back_populates="cue_type", cascade="all, delete-orphan") class Cue(db.Model): - __tablename__ = 'cue' + __tablename__ = "cue" id = Column(Integer, primary_key=True, autoincrement=True) - cue_type_id = Column(Integer, ForeignKey('cuetypes.id')) + cue_type_id = Column(Integer, ForeignKey("cuetypes.id")) ident = Column(String()) - cue_type = relationship('CueType', uselist=False, foreign_keys=[cue_type_id], - back_populates='cues') + cue_type = relationship( + "CueType", uselist=False, foreign_keys=[cue_type_id], back_populates="cues" + ) class CueAssociation(db.Model, DeleteMixin): - __tablename__ = 'script_cue_association' - __mapper_args__ = { - 'confirm_deleted_rows': False - } - - revision_id = Column(Integer, ForeignKey('script_revisions.id'), primary_key=True, index=True) - line_id = Column(Integer, ForeignKey('script_lines.id'), primary_key=True, index=True) - cue_id = Column(Integer, ForeignKey('cue.id'), primary_key=True, index=True) - - revision: ScriptRevision = relationship('ScriptRevision', foreign_keys=[revision_id], - uselist=False, - backref=backref('cue_associations', uselist=True, - cascade='all, delete-orphan')) - line: ScriptLine = relationship('ScriptLine', foreign_keys=[line_id], uselist=False, - backref=backref('cue_associations', uselist=True, - viewonly=True)) - cue: Cue = relationship('Cue', foreign_keys=[cue_id], uselist=False, - backref=backref('revision_associations', uselist=True, - cascade='all, delete-orphan')) + __tablename__ = "script_cue_association" + __mapper_args__ = {"confirm_deleted_rows": False} + + revision_id = Column( + Integer, ForeignKey("script_revisions.id"), primary_key=True, index=True + ) + line_id = Column( + Integer, ForeignKey("script_lines.id"), primary_key=True, index=True + ) + cue_id = Column(Integer, ForeignKey("cue.id"), primary_key=True, index=True) + + revision: ScriptRevision = relationship( + "ScriptRevision", + foreign_keys=[revision_id], + uselist=False, + backref=backref("cue_associations", uselist=True, cascade="all, delete-orphan"), + ) + line: ScriptLine = relationship( + "ScriptLine", + foreign_keys=[line_id], + uselist=False, + backref=backref("cue_associations", uselist=True, viewonly=True), + ) + cue: Cue = relationship( + "Cue", + foreign_keys=[cue_id], + uselist=False, + backref=backref( + "revision_associations", uselist=True, cascade="all, delete-orphan" + ), + ) def pre_delete(self, session): if len(self.cue.revision_associations) == 1: diff --git a/server/models/mics.py b/server/models/mics.py index 0c712819..203c1682 100644 --- a/server/models/mics.py +++ b/server/models/mics.py @@ -1,32 +1,38 @@ -from sqlalchemy import Column, Integer, ForeignKey, String -from sqlalchemy.orm import relationship, backref +from sqlalchemy import Column, ForeignKey, Integer, String +from sqlalchemy.orm import backref, relationship from models.models import db class Microphone(db.Model): - __tablename__ = 'microphones' + __tablename__ = "microphones" id = Column(Integer, primary_key=True, autoincrement=True) - show_id = Column(Integer, ForeignKey('shows.id')) + show_id = Column(Integer, ForeignKey("shows.id")) name = Column(String) description = Column(String) class MicrophoneAllocation(db.Model): - __tablename__ = 'microphone_allocations' - - mic_id = Column(Integer, ForeignKey('microphones.id'), primary_key=True) - scene_id = Column(Integer, ForeignKey('scene.id'), primary_key=True) - character_id = Column(Integer, ForeignKey('character.id'), primary_key=True) - - microphone = relationship('Microphone', uselist=False, - backref=backref('allocations', uselist=True, - cascade='all, delete-orphan')) - scene = relationship('Scene', uselist=False, - backref=backref('mic_allocations', uselist=True, - cascade='all, delete-orphan')) - character = relationship('Character', uselist=False, - backref=backref('mic_allocations', uselist=True, - cascade='all, delete-orphan')) + __tablename__ = "microphone_allocations" + + mic_id = Column(Integer, ForeignKey("microphones.id"), primary_key=True) + scene_id = Column(Integer, ForeignKey("scene.id"), primary_key=True) + character_id = Column(Integer, ForeignKey("character.id"), primary_key=True) + + microphone = relationship( + "Microphone", + uselist=False, + backref=backref("allocations", uselist=True, cascade="all, delete-orphan"), + ) + scene = relationship( + "Scene", + uselist=False, + backref=backref("mic_allocations", uselist=True, cascade="all, delete-orphan"), + ) + character = relationship( + "Character", + uselist=False, + backref=backref("mic_allocations", uselist=True, cascade="all, delete-orphan"), + ) diff --git a/server/models/models.py b/server/models/models.py index 4eb2fde8..53a8f1ef 100644 --- a/server/models/models.py +++ b/server/models/models.py @@ -8,10 +8,10 @@ def import_all_models(): - models = find_end_modules('.', prefix='models') + models = find_end_modules(".", prefix="models") for model in models: if model != __name__: - get_logger().debug(f'Importing model module {model}') + get_logger().debug(f"Importing model module {model}") mod = importlib.import_module(model) IMPORTED_MODELS[model] = mod diff --git a/server/models/script.py b/server/models/script.py index d9bea858..461d6a7c 100644 --- a/server/models/script.py +++ b/server/models/script.py @@ -1,80 +1,91 @@ -from sqlalchemy import Column, Integer, ForeignKey, DateTime, String, Boolean -from sqlalchemy.orm import relationship, backref +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String +from sqlalchemy.orm import backref, relationship from models.models import db from utils.database import DeleteMixin class Script(db.Model): - __tablename__ = 'script' + __tablename__ = "script" id = Column(Integer, primary_key=True, autoincrement=True) - show_id = Column(Integer, ForeignKey('shows.id')) - current_revision = Column(Integer, ForeignKey('script_revisions.id')) + show_id = Column(Integer, ForeignKey("shows.id")) + current_revision = Column(Integer, ForeignKey("script_revisions.id")) - revisions = relationship('ScriptRevision', uselist=True, - primaryjoin='ScriptRevision.script_id == Script.id') + revisions = relationship( + "ScriptRevision", + uselist=True, + primaryjoin="ScriptRevision.script_id == Script.id", + ) class ScriptRevision(db.Model): - __tablename__ = 'script_revisions' - __mapper_args__ = { - 'confirm_deleted_rows': False - } + __tablename__ = "script_revisions" + __mapper_args__ = {"confirm_deleted_rows": False} id = Column(Integer, primary_key=True, autoincrement=True) - script_id = Column(Integer, ForeignKey('script.id')) + script_id = Column(Integer, ForeignKey("script.id")) revision = Column(Integer) created_at = Column(DateTime) edited_at = Column(DateTime) description = Column(String) - previous_revision_id = Column(Integer, ForeignKey('script_revisions.id', ondelete='SET NULL')) + previous_revision_id = Column( + Integer, ForeignKey("script_revisions.id", ondelete="SET NULL") + ) - previous_revision = relationship('ScriptRevision', foreign_keys=[previous_revision_id]) + previous_revision = relationship( + "ScriptRevision", foreign_keys=[previous_revision_id] + ) class ScriptLine(db.Model): - __tablename__ = 'script_lines' + __tablename__ = "script_lines" id = Column(Integer, primary_key=True, autoincrement=True) - act_id = Column(Integer, ForeignKey('act.id')) - scene_id = Column(Integer, ForeignKey('scene.id')) + act_id = Column(Integer, ForeignKey("act.id")) + scene_id = Column(Integer, ForeignKey("scene.id")) page = Column(Integer, index=True) stage_direction = Column(Boolean) - stage_direction_style_id = Column(Integer, ForeignKey('stage_direction_styles.id', ondelete='SET NULL')) + stage_direction_style_id = Column( + Integer, ForeignKey("stage_direction_styles.id", ondelete="SET NULL") + ) - act = relationship('Act', uselist=False, back_populates='lines') - scene = relationship('Scene', uselist=False, back_populates='lines') - stage_direction_style = relationship('StageDirectionStyle', uselist=False) + act = relationship("Act", uselist=False, back_populates="lines") + scene = relationship("Scene", uselist=False, back_populates="lines") + stage_direction_style = relationship("StageDirectionStyle", uselist=False) class ScriptLineRevisionAssociation(db.Model, DeleteMixin): - __tablename__ = 'script_line_revision_association' - __mapper_args__ = { - 'confirm_deleted_rows': False - } - - revision_id = Column(Integer, ForeignKey('script_revisions.id'), primary_key=True, index=True) - line_id = Column(Integer, ForeignKey('script_lines.id'), primary_key=True, index=True) - - next_line_id = Column(Integer, ForeignKey('script_lines.id')) - previous_line_id = Column(Integer, ForeignKey('script_lines.id')) - - revision: ScriptRevision = relationship('ScriptRevision', - foreign_keys=[revision_id], - uselist=False, - backref=backref('line_associations', - uselist=True, - cascade='all, delete')) - line: ScriptLine = relationship('ScriptLine', - foreign_keys=[line_id], - uselist=False, - backref=backref('revision_associations', - uselist=True, - cascade='all, delete')) - next_line: ScriptLine = relationship('ScriptLine', foreign_keys=[next_line_id]) - previous_line: ScriptLine = relationship('ScriptLine', foreign_keys=[previous_line_id]) + __tablename__ = "script_line_revision_association" + __mapper_args__ = {"confirm_deleted_rows": False} + + revision_id = Column( + Integer, ForeignKey("script_revisions.id"), primary_key=True, index=True + ) + line_id = Column( + Integer, ForeignKey("script_lines.id"), primary_key=True, index=True + ) + + next_line_id = Column(Integer, ForeignKey("script_lines.id")) + previous_line_id = Column(Integer, ForeignKey("script_lines.id")) + + revision: ScriptRevision = relationship( + "ScriptRevision", + foreign_keys=[revision_id], + uselist=False, + backref=backref("line_associations", uselist=True, cascade="all, delete"), + ) + line: ScriptLine = relationship( + "ScriptLine", + foreign_keys=[line_id], + uselist=False, + backref=backref("revision_associations", uselist=True, cascade="all, delete"), + ) + next_line: ScriptLine = relationship("ScriptLine", foreign_keys=[next_line_id]) + previous_line: ScriptLine = relationship( + "ScriptLine", foreign_keys=[previous_line_id] + ) def pre_delete(self, session): if self.line and len(self.line.revision_associations) == 1: @@ -85,42 +96,55 @@ def post_delete(self, session): class ScriptLinePart(db.Model): - __tablename__ = 'script_line_parts' + __tablename__ = "script_line_parts" id = Column(Integer, primary_key=True, autoincrement=True) - line_id = Column(Integer, ForeignKey('script_lines.id')) + line_id = Column(Integer, ForeignKey("script_lines.id")) part_index = Column(Integer) - character_id = Column(Integer, ForeignKey('character.id')) - character_group_id = Column(Integer, ForeignKey('character_group.id')) + character_id = Column(Integer, ForeignKey("character.id")) + character_group_id = Column(Integer, ForeignKey("character_group.id")) line_text = Column(String) - line = relationship('ScriptLine', uselist=False, foreign_keys=[line_id], - backref=backref('line_parts', uselist=True, - cascade='all, delete-orphan')) - character = relationship('Character', uselist=False) - character_group = relationship('CharacterGroup', uselist=False) + line = relationship( + "ScriptLine", + uselist=False, + foreign_keys=[line_id], + backref=backref("line_parts", uselist=True, cascade="all, delete-orphan"), + ) + character = relationship("Character", uselist=False) + character_group = relationship("CharacterGroup", uselist=False) class ScriptCuts(db.Model): - __tablename__ = 'script_line_cuts' - - line_part_id = Column(Integer, ForeignKey('script_line_parts.id'), primary_key=True, index=True) - revision_id = Column(Integer, ForeignKey('script_revisions.id'), primary_key=True, index=True) - - line_part = relationship('ScriptLinePart', uselist=False, foreign_keys=[line_part_id], - backref=backref('line_part_cuts', uselist=False, - cascade='all, delete-orphan')) - revision = relationship('ScriptRevision', uselist=False, foreign_keys=[revision_id], - backref=backref('line_part_cuts', uselist=True, - cascade='all, delete-orphan')) + __tablename__ = "script_line_cuts" + + line_part_id = Column( + Integer, ForeignKey("script_line_parts.id"), primary_key=True, index=True + ) + revision_id = Column( + Integer, ForeignKey("script_revisions.id"), primary_key=True, index=True + ) + + line_part = relationship( + "ScriptLinePart", + uselist=False, + foreign_keys=[line_part_id], + backref=backref("line_part_cuts", uselist=False, cascade="all, delete-orphan"), + ) + revision = relationship( + "ScriptRevision", + uselist=False, + foreign_keys=[revision_id], + backref=backref("line_part_cuts", uselist=True, cascade="all, delete-orphan"), + ) class StageDirectionStyle(db.Model): - __tablename__ = 'stage_direction_styles' + __tablename__ = "stage_direction_styles" id = Column(Integer, primary_key=True, autoincrement=True) - script_id = Column(Integer, ForeignKey('script.id'), index=True) + script_id = Column(Integer, ForeignKey("script.id"), index=True) description = Column(String) bold = Column(Boolean) @@ -131,5 +155,11 @@ class StageDirectionStyle(db.Model): enable_background_colour = Column(Boolean) background_colour = Column(String) - script=relationship('Script', uselist=False, foreign_keys=[script_id], - backref=backref('stage_direction_styles', uselist=True, cascade='all, delete-orphan')) + script = relationship( + "Script", + uselist=False, + foreign_keys=[script_id], + backref=backref( + "stage_direction_styles", uselist=True, cascade="all, delete-orphan" + ), + ) diff --git a/server/models/session.py b/server/models/session.py index 3a7ae4b0..c59414df 100644 --- a/server/models/session.py +++ b/server/models/session.py @@ -1,37 +1,42 @@ -from sqlalchemy import Column, String, Float, Boolean, Integer, ForeignKey, DateTime -from sqlalchemy.orm import relationship, backref +from sqlalchemy import Boolean, Column, DateTime, Float, ForeignKey, Integer, String +from sqlalchemy.orm import backref, relationship from models.models import db class Session(db.Model): - __tablename__ = 'sessions' + __tablename__ = "sessions" internal_id = Column(String(255), primary_key=True) remote_ip = Column(String(255)) last_ping = Column(Float()) last_pong = Column(Float()) is_editor = Column(Boolean(), default=False, index=True) - user_id = Column(Integer, ForeignKey('user.id'), index=True) + user_id = Column(Integer, ForeignKey("user.id"), index=True) - user = relationship('User', uselist=False, - backref=backref('sessions', uselist=True)) + user = relationship( + "User", uselist=False, backref=backref("sessions", uselist=True) + ) class ShowSession(db.Model): - __tablename__ = 'showsession' + __tablename__ = "showsession" id = Column(Integer, primary_key=True, autoincrement=True) - show_id = Column(Integer, ForeignKey('shows.id')) + show_id = Column(Integer, ForeignKey("shows.id")) start_date_time = Column(DateTime()) end_date_time = Column(DateTime()) - user_id = Column(Integer, ForeignKey('user.id'), index=True) - client_internal_id = Column(String(255), ForeignKey('sessions.internal_id')) + user_id = Column(Integer, ForeignKey("user.id"), index=True) + client_internal_id = Column(String(255), ForeignKey("sessions.internal_id")) last_client_internal_id = Column(String(255)) latest_line_ref = Column(String) - show = relationship('Show', uselist=False, foreign_keys=[show_id]) - user = relationship('User', uselist=False, foreign_keys=[user_id]) - client = relationship('Session', uselist=False, foreign_keys=[client_internal_id], - backref=backref('live_session', uselist=False)) + show = relationship("Show", uselist=False, foreign_keys=[show_id]) + user = relationship("User", uselist=False, foreign_keys=[user_id]) + client = relationship( + "Session", + uselist=False, + foreign_keys=[client_internal_id], + backref=backref("live_session", uselist=False), + ) diff --git a/server/models/show.py b/server/models/show.py index 0fe3ab54..8c4b1682 100644 --- a/server/models/show.py +++ b/server/models/show.py @@ -1,11 +1,20 @@ -from sqlalchemy import Column, Integer, String, Date, DateTime, ForeignKey, Table, Boolean -from sqlalchemy.orm import relationship, backref +from sqlalchemy import ( + Boolean, + Column, + Date, + DateTime, + ForeignKey, + Integer, + String, + Table, +) +from sqlalchemy.orm import backref, relationship from models.models import db class Show(db.Model): - __tablename__ = 'shows' + __tablename__ = "shows" id = Column(Integer, primary_key=True, autoincrement=True) name = Column(String(100)) @@ -13,103 +22,123 @@ class Show(db.Model): end_date = Column(Date()) created_at = Column(DateTime()) edited_at = Column(DateTime()) - first_act_id = Column(Integer, ForeignKey('act.id')) - current_session_id = Column(Integer, ForeignKey('showsession.id')) + first_act_id = Column(Integer, ForeignKey("act.id")) + current_session_id = Column(Integer, ForeignKey("showsession.id")) # Relationships - first_act = relationship('Act', uselist=False, foreign_keys=[first_act_id]) - current_session = relationship('ShowSession', uselist=False, foreign_keys=[current_session_id]) - - cast_list = relationship("Cast", - cascade='all, delete-orphan') - character_list = relationship('Character', - cascade='all, delete-orphan') - character_group_list = relationship('CharacterGroup', - cascade='all, delete-orphan') - act_list = relationship('Act', primaryjoin=lambda: Show.id == Act.show_id, - cascade='all, delete-orphan') - scene_list = relationship('Scene', - cascade='all, delete-orphan') - cue_type_list = relationship('CueType', - cascade='all, delete-orphan') - users = relationship('User', uselist=True, cascade='all, delete-orphan') + first_act = relationship("Act", uselist=False, foreign_keys=[first_act_id]) + current_session = relationship( + "ShowSession", uselist=False, foreign_keys=[current_session_id] + ) + + cast_list = relationship("Cast", cascade="all, delete-orphan") + character_list = relationship("Character", cascade="all, delete-orphan") + character_group_list = relationship("CharacterGroup", cascade="all, delete-orphan") + act_list = relationship( + "Act", primaryjoin=lambda: Show.id == Act.show_id, cascade="all, delete-orphan" + ) + scene_list = relationship("Scene", cascade="all, delete-orphan") + cue_type_list = relationship("CueType", cascade="all, delete-orphan") + users = relationship("User", uselist=True, cascade="all, delete-orphan") class Cast(db.Model): - __tablename__ = 'cast' + __tablename__ = "cast" id = Column(Integer, primary_key=True, autoincrement=True) - show_id = Column(Integer, ForeignKey('shows.id')) + show_id = Column(Integer, ForeignKey("shows.id")) first_name = Column(String) last_name = Column(String) # Relationships - character_list = relationship('Character', back_populates='cast_member') + character_list = relationship("Character", back_populates="cast_member") character_group_association_table = Table( "character_group_association", db.Model.metadata, - Column('character_group_id', ForeignKey('character_group.id'), primary_key=True), - Column('character_id', ForeignKey('character.id'), primary_key=True) + Column("character_group_id", ForeignKey("character_group.id"), primary_key=True), + Column("character_id", ForeignKey("character.id"), primary_key=True), ) class Character(db.Model): - __tablename__ = 'character' + __tablename__ = "character" id = Column(Integer, primary_key=True, autoincrement=True) - show_id = Column(Integer, ForeignKey('shows.id')) - played_by = Column(Integer, ForeignKey('cast.id')) + show_id = Column(Integer, ForeignKey("shows.id")) + played_by = Column(Integer, ForeignKey("cast.id")) name = Column(String) description = Column(String) - cast_member = relationship("Cast", back_populates='character_list') - character_groups = relationship('CharacterGroup', secondary=character_group_association_table, - back_populates='characters') + cast_member = relationship("Cast", back_populates="character_list") + character_groups = relationship( + "CharacterGroup", + secondary=character_group_association_table, + back_populates="characters", + ) class CharacterGroup(db.Model): - __tablename__ = 'character_group' + __tablename__ = "character_group" id = Column(Integer, primary_key=True, autoincrement=True) - show_id = Column(Integer, ForeignKey('shows.id')) + show_id = Column(Integer, ForeignKey("shows.id")) name = Column(String) description = Column(String) - characters = relationship('Character', secondary=character_group_association_table, - back_populates='character_groups') + characters = relationship( + "Character", + secondary=character_group_association_table, + back_populates="character_groups", + ) class Act(db.Model): - __tablename__ = 'act' + __tablename__ = "act" id = Column(Integer, primary_key=True, autoincrement=True) - show_id = Column(Integer, ForeignKey('shows.id')) + show_id = Column(Integer, ForeignKey("shows.id")) name = Column(String) interval_after = Column(Boolean) - first_scene_id = Column(Integer, ForeignKey('scene.id')) - previous_act_id = Column(Integer, ForeignKey('act.id')) + first_scene_id = Column(Integer, ForeignKey("scene.id")) + previous_act_id = Column(Integer, ForeignKey("act.id")) - first_scene = relationship('Scene', uselist=False, foreign_keys=[first_scene_id]) - previous_act = relationship('Act', uselist=False, remote_side=[id], - backref=backref('next_act', uselist=False)) - lines = relationship('ScriptLine', back_populates='act', cascade='all, delete-orphan') + first_scene = relationship("Scene", uselist=False, foreign_keys=[first_scene_id]) + previous_act = relationship( + "Act", + uselist=False, + remote_side=[id], + backref=backref("next_act", uselist=False), + ) + lines = relationship( + "ScriptLine", back_populates="act", cascade="all, delete-orphan" + ) class Scene(db.Model): - __tablename__ = 'scene' + __tablename__ = "scene" id = Column(Integer, primary_key=True, autoincrement=True) - show_id = Column(Integer, ForeignKey('shows.id')) - act_id = Column(Integer, ForeignKey('act.id')) + show_id = Column(Integer, ForeignKey("shows.id")) + act_id = Column(Integer, ForeignKey("act.id")) name = Column(String) - previous_scene_id = Column(Integer, ForeignKey('scene.id')) - - act = relationship('Act', uselist=False, - backref=backref('scene_list', cascade='all, delete-orphan'), - foreign_keys=[act_id], post_update=True) - previous_scene = relationship('Scene', uselist=False, remote_side=[id], - backref=backref('next_scene', uselist=False)) - lines = relationship('ScriptLine', back_populates='scene', cascade='all, delete-orphan') + previous_scene_id = Column(Integer, ForeignKey("scene.id")) + + act = relationship( + "Act", + uselist=False, + backref=backref("scene_list", cascade="all, delete-orphan"), + foreign_keys=[act_id], + post_update=True, + ) + previous_scene = relationship( + "Scene", + uselist=False, + remote_side=[id], + backref=backref("next_scene", uselist=False), + ) + lines = relationship( + "ScriptLine", back_populates="scene", cascade="all, delete-orphan" + ) diff --git a/server/models/user.py b/server/models/user.py index dfb81ae0..45ae8c69 100644 --- a/server/models/user.py +++ b/server/models/user.py @@ -1,14 +1,14 @@ -from sqlalchemy import Integer, Column, String, ForeignKey, Boolean, DateTime +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String from models.models import db class User(db.Model): - __tablename__ = 'user' + __tablename__ = "user" id = Column(Integer(), primary_key=True, autoincrement=True) username = Column(String(), index=True) password = Column(String()) - show_id = Column(Integer(), ForeignKey('shows.id'), index=True) + show_id = Column(Integer(), ForeignKey("shows.id"), index=True) is_admin = Column(Boolean()) last_login = Column(DateTime()) diff --git a/server/rbac/rbac.py b/server/rbac/rbac.py index 7fa9b4bc..6e762bb0 100644 --- a/server/rbac/rbac.py +++ b/server/rbac/rbac.py @@ -1,4 +1,4 @@ -from typing import Optional, TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional from models.models import db from rbac.exceptions import RBACException @@ -12,20 +12,22 @@ class RBACController: - def __init__(self, app: 'DigiScriptServer'): + def __init__(self, app: "DigiScriptServer"): self.app = app self._rbac_db = RBACDatabase(app.get_db(), app) self._display_fields = {} - def add_mapping(self, actor: type, resource: type, display_fields: Optional[List] = None) -> None: + def add_mapping( + self, actor: type, resource: type, display_fields: Optional[List] = None + ) -> None: if display_fields is None: display_fields = [] if not get_registry().get_schema_by_model(actor): - raise RBACException('actor does not have a registered schema') + raise RBACException("actor does not have a registered schema") if not get_registry().get_schema_by_model(resource): - raise RBACException('resource does not have a registered schema') + raise RBACException("resource does not have a registered schema") if len(display_fields) > 3: - raise RBACException('Only 3 or fewer display fields are allowed') + raise RBACException("Only 3 or fewer display fields are allowed") self._rbac_db.add_mapping(actor, resource) self._display_fields[resource] = [field.key for field in display_fields] diff --git a/server/rbac/rbac_db.py b/server/rbac/rbac_db.py index cca69de7..30fd0d2d 100644 --- a/server/rbac/rbac_db.py +++ b/server/rbac/rbac_db.py @@ -1,10 +1,10 @@ import functools from collections import defaultdict from copy import deepcopy -from typing import Optional, List, TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Dict, List, Optional from anytree import Node -from sqlalchemy import inspect, Column, ForeignKey, Integer, TypeDecorator, Table +from sqlalchemy import Column, ForeignKey, Integer, Table, TypeDecorator, inspect from digi_server.logger import get_logger from models.models import db @@ -27,12 +27,16 @@ class RoleCol(TypeDecorator): def process_bind_param(self, value, dialect): if not isinstance(value, Role): - raise Exception(f'RoleCol data type is incorrect. Got {type(value)} but should be Role') + raise Exception( + f"RoleCol data type is incorrect. Got {type(value)} but should be Role" + ) return value.value if value is not None else None def process_literal_param(self, value, dialect): if not isinstance(value, Role): - raise Exception(f'RoleCol data type is incorrect. Got {type(value)} but should be Role') + raise Exception( + f"RoleCol data type is incorrect. Got {type(value)} but should be Role" + ) return value.value if value is not None else None @property @@ -47,20 +51,28 @@ def _get_mapping_columns(actor: db.Model, resource: db.Model) -> dict: actor_inspect = inspect(actor) resource_inspect = inspect(resource) cols = {} - cols.update({ - f'{actor_inspect.mapper.mapped_table.fullname}_{col.key}': getattr(actor, col.key) for col in - actor_inspect.mapper.primary_key - }) - cols.update({ - f'{resource_inspect.mapper.mapped_table.fullname}_{col.key}': getattr(resource, col.key) for col in - resource_inspect.mapper.primary_key - }) + cols.update( + { + f"{actor_inspect.mapper.mapped_table.fullname}_{col.key}": getattr( + actor, col.key + ) + for col in actor_inspect.mapper.primary_key + } + ) + cols.update( + { + f"{resource_inspect.mapper.mapped_table.fullname}_{col.key}": getattr( + resource, col.key + ) + for col in resource_inspect.mapper.primary_key + } + ) return cols class RBACDatabase: - def __init__(self, _db: DigiSQLAlchemy, app: 'DigiScriptServer'): + def __init__(self, _db: DigiSQLAlchemy, app: "DigiScriptServer"): self._db: DigiSQLAlchemy = _db self._app = app self._mappings = {} @@ -69,58 +81,67 @@ def __init__(self, _db: DigiSQLAlchemy, app: 'DigiScriptServer'): def add_mapping(self, actor: type, resource: type) -> None: if not isinstance(actor, type): - raise RBACException('actor must be class object, not instance') + raise RBACException("actor must be class object, not instance") if not isinstance(resource, type): - raise RBACException('resource must be class object, not instance') + raise RBACException("resource must be class object, not instance") actor_inspect = inspect(actor) resource_inspect = inspect(resource) if not self._has_link_to_show(actor_inspect.mapped_table): - raise RBACException('actor class does not have a reference back to Show table') + raise RBACException( + "actor class does not have a reference back to Show table" + ) if not self._has_link_to_show(resource_inspect.mapped_table): - raise RBACException('resource class does not have a reference back to Show table') + raise RBACException( + "resource class does not have a reference back to Show table" + ) - table_name = f'rbac_{actor_inspect.mapped_table.fullname}_{resource_inspect.mapped_table.fullname}' + table_name = f"rbac_{actor_inspect.mapped_table.fullname}_{resource_inspect.mapped_table.fullname}" if table_name in self._mappings: - raise RBACException(f'RBAC mapping {table_name} already exists') + raise RBACException(f"RBAC mapping {table_name} already exists") actor_columns = { - f'{actor_inspect.mapped_table.fullname}_{col.key}': Column(col.type, ForeignKey( - f'{actor_inspect.mapped_table.fullname}.{col.key}'), primary_key=True) for col in - actor_inspect.primary_key + f"{actor_inspect.mapped_table.fullname}_{col.key}": Column( + col.type, + ForeignKey(f"{actor_inspect.mapped_table.fullname}.{col.key}"), + primary_key=True, + ) + for col in actor_inspect.primary_key } resource_columns = { - f'{resource_inspect.mapped_table.fullname}_{col.key}': Column(col.type, ForeignKey( - f'{resource_inspect.mapped_table.fullname}.{col.key}'), primary_key=True) for col in - resource_inspect.primary_key + f"{resource_inspect.mapped_table.fullname}_{col.key}": Column( + col.type, + ForeignKey(f"{resource_inspect.mapped_table.fullname}.{col.key}"), + primary_key=True, + ) + for col in resource_inspect.primary_key } - attr_dict = { - '__tablename__': table_name, - 'rbac_permissions': Column(RoleCol()) - } + attr_dict = {"__tablename__": table_name, "rbac_permissions": Column(RoleCol())} attr_dict.update(actor_columns) attr_dict.update(resource_columns) rbac_class = type(table_name, (db.Model,), attr_dict) self._mappings[table_name] = rbac_class self._resource_mappings[actor_inspect.mapped_table.fullname].append(resource) - logger.info(f'Created RBAC mapping {table_name}') + logger.info(f"Created RBAC mapping {table_name}") def _validate_mapping(self, actor: db.Model, resource: db.Model) -> str: if not isinstance(actor, db.Model): - raise RBACException('actor must be class instance, not object') + raise RBACException("actor must be class instance, not object") if not isinstance(resource, db.Model): - raise RBACException('resource must be class instance, not object') + raise RBACException("resource must be class instance, not object") actor_inspect = inspect(actor) resource_inspect = inspect(resource) - table_name = (f'rbac_{actor_inspect.mapper.mapped_table.fullname}_' - f'{resource_inspect.mapper.mapped_table.fullname}') + table_name = ( + f"rbac_{actor_inspect.mapper.mapped_table.fullname}_" + f"{resource_inspect.mapper.mapped_table.fullname}" + ) if table_name not in self._mappings: - raise RBACException('Mapping for actor and resource not created') + raise RBACException("Mapping for actor and resource not created") return table_name @@ -145,7 +166,9 @@ def revoke_role(self, actor: db.Model, resource: db.Model, role: Role) -> None: with self._db.sessionmaker() as session: rbac_assignment = session.query(self._mappings[table_name]).get(cols) if not rbac_assignment: - raise RBACException('actor does not have any roles assigned for the resource') + raise RBACException( + "actor does not have any roles assigned for the resource" + ) rbac_assignment.rbac_permissions &= ~role session.commit() @@ -173,7 +196,9 @@ def get_all_roles(self, actor: db.Model) -> Dict: for resource in resources: resource_inspect = inspect(resource) for rbac_object in self.get_objects_for_resource(resource): - roles[resource_inspect.mapped_table.fullname].append([rbac_object, self.get_roles(actor, rbac_object)]) + roles[resource_inspect.mapped_table.fullname].append( + [rbac_object, self.get_roles(actor, rbac_object)] + ) return roles @functools.lru_cache() @@ -187,8 +212,10 @@ def __has_link_to_show(self, table: Table, checked_tables=None): return True if table.foreign_key_constraints and table.fullname not in checked_tables: checked_tables.append(table.fullname) - return any(self.__has_link_to_show(fkc.referred_table, checked_tables) - for fkc in table.foreign_key_constraints) + return any( + self.__has_link_to_show(fkc.referred_table, checked_tables) + for fkc in table.foreign_key_constraints + ) return False @functools.lru_cache() @@ -198,7 +225,7 @@ def _get_link_to_show(self, table: Table): root = Node(table.fullname, table=table) self.__get_link_to_show(table, root) - return tree.flatten(root, attr='table') + return tree.flatten(root, attr="table") def __get_link_to_show(self, table: Table, root: Node, checked_tables=None): if checked_tables is None: @@ -217,13 +244,17 @@ def __get_link_to_show(self, table: Table, root: Node, checked_tables=None): if self._has_link_to_show(fkc.referred_table): self.__get_link_to_show( fkc.referred_table, - Node(fkc.referred_table.fullname, table=fkc.referred_table, parent=root), - checked_tables + Node( + fkc.referred_table.fullname, + table=fkc.referred_table, + parent=root, + ), + checked_tables, ) def get_resources_for_actor(self, actor: db.Model) -> Optional[List[db.Model]]: if not isinstance(actor, type): - raise RBACException('actor must be class object, not instance') + raise RBACException("actor must be class object, not instance") actor_inspect = inspect(actor) if actor_inspect.mapped_table.fullname in self._resource_mappings: return self._resource_mappings[actor_inspect.mapped_table.fullname] @@ -231,9 +262,9 @@ def get_resources_for_actor(self, actor: db.Model) -> Optional[List[db.Model]]: def get_objects_for_resource(self, resource: db.Model) -> Optional[List[db.Model]]: if not isinstance(resource, type): - raise RBACException('resource must be class object, not instance') + raise RBACException("resource must be class object, not instance") - current_show = self._app.digi_settings.settings.get('current_show').get_value() + current_show = self._app.digi_settings.settings.get("current_show").get_value() if not current_show: return [] @@ -265,12 +296,21 @@ def get_objects_for_resource(self, resource: db.Model) -> Optional[List[db.Model cols = {} for foreign_key in table.foreign_keys: fk_table = foreign_key.constraint.referred_table - if fk_table.fullname == previous_inspect.mapper.mapped_table.fullname: - cols[foreign_key.parent.key] = getattr(prev_entity, foreign_key.column.key) + if ( + fk_table.fullname + == previous_inspect.mapper.mapped_table.fullname + ): + cols[foreign_key.parent.key] = getattr( + prev_entity, foreign_key.column.key + ) if cols: - results.extend(session.query( - self._db.get_mapper_for_table(table.fullname)). - filter_by(**cols).all()) + results.extend( + session.query( + self._db.get_mapper_for_table(table.fullname) + ) + .filter_by(**cols) + .all() + ) previous_entities = results if valid: final_results.update(previous_entities) diff --git a/server/schemas/schemas.py b/server/schemas/schemas.py index 8625832b..bfd11a75 100644 --- a/server/schemas/schemas.py +++ b/server/schemas/schemas.py @@ -1,14 +1,21 @@ from typing import Optional -from marshmallow_sqlalchemy import SQLAlchemyAutoSchema, auto_field, SQLAlchemySchema +from marshmallow_sqlalchemy import SQLAlchemyAutoSchema, SQLAlchemySchema, auto_field from marshmallow_sqlalchemy.fields import Nested -from models.cue import CueType, Cue +from models.cue import Cue, CueType from models.mics import Microphone, MicrophoneAllocation from models.models import db -from models.script import ScriptRevision, ScriptLine, ScriptLinePart, Script, ScriptCuts, StageDirectionStyle -from models.show import Show, Cast, Character, CharacterGroup, Act, Scene +from models.script import ( + Script, + ScriptCuts, + ScriptLine, + ScriptLinePart, + ScriptRevision, + StageDirectionStyle, +) from models.session import Session, ShowSession +from models.show import Act, Cast, Character, CharacterGroup, Scene, Show from models.user import User @@ -55,7 +62,7 @@ class Meta: model = User load_instance = True include_fk = True - exclude = ('password',) + exclude = ("password",) @schema @@ -75,7 +82,9 @@ class Meta: include_relationships = True load_instance = True - character_list = Nested(lambda: CharacterSchema, many=True, exclude=('cast_member',)) + character_list = Nested( + lambda: CharacterSchema, many=True, exclude=("cast_member",) + ) @schema @@ -85,7 +94,7 @@ class Meta: include_relationships = True load_instance = True - cast_member = Nested(CastSchema, many=False, exclude=('character_list',)) + cast_member = Nested(CastSchema, many=False, exclude=("character_list",)) @schema diff --git a/server/test/test_auth_api.py b/server/test/test_auth_api.py index 702dfa4e..18a4ba85 100644 --- a/server/test/test_auth_api.py +++ b/server/test/test_auth_api.py @@ -6,144 +6,183 @@ class TestAuthAPI(DigiScriptTestCase): def test_get(self): - response = self.fetch('/api/v1/auth/create') + response = self.fetch("/api/v1/auth/create") self.assertEqual(405, response.code) def test_empty_post(self): - response = self.fetch('/api/v1/auth/create', - method='POST', - body=escape.json_encode({})) + response = self.fetch( + "/api/v1/auth/create", method="POST", body=escape.json_encode({}) + ) response_body = escape.json_decode(response.body) self.assertEqual(400, response.code) - self.assertTrue('message' in response_body) - self.assertEqual('Username missing', response_body['message']) + self.assertTrue("message" in response_body) + self.assertEqual("Username missing", response_body["message"]) def test_missing_password(self): - response = self.fetch('/api/v1/auth/create', - method='POST', - body=escape.json_encode({ - 'username': 'foobar' - })) + response = self.fetch( + "/api/v1/auth/create", + method="POST", + body=escape.json_encode({"username": "foobar"}), + ) response_body = escape.json_decode(response.body) self.assertEqual(400, response.code) - self.assertTrue('message' in response_body) - self.assertEqual('Password missing', response_body['message']) + self.assertTrue("message" in response_body) + self.assertEqual("Password missing", response_body["message"]) def test_create_admin(self): - response = self.fetch('/api/v1/auth/create', - method='POST', - body=escape.json_encode({ - 'username': 'foobar', - 'password': 'password', - 'is_admin': True, - 'show_id': None - })) + response = self.fetch( + "/api/v1/auth/create", + method="POST", + body=escape.json_encode( + { + "username": "foobar", + "password": "password", + "is_admin": True, + "show_id": None, + } + ), + ) response_body = escape.json_decode(response.body) self.assertEqual(200, response.code) - self.assertTrue('message' in response_body) - self.assertEqual('Successfully created user', response_body['message']) + self.assertTrue("message" in response_body) + self.assertEqual("Successfully created user", response_body["message"]) def test_invalid_admin(self): - response = self.fetch('/api/v1/auth/create', - method='POST', - body=escape.json_encode({ - 'username': 'foobar', - 'password': 'password', - 'is_admin': True, - 'show_id': 1 - })) + response = self.fetch( + "/api/v1/auth/create", + method="POST", + body=escape.json_encode( + { + "username": "foobar", + "password": "password", + "is_admin": True, + "show_id": 1, + } + ), + ) response_body = escape.json_decode(response.body) self.assertEqual(400, response.code) - self.assertTrue('message' in response_body) - self.assertEqual('Admin user cannot have a show allocation', response_body['message']) + self.assertTrue("message" in response_body) + self.assertEqual( + "Admin user cannot have a show allocation", response_body["message"] + ) def test_invalid_user(self): - response = self.fetch('/api/v1/auth/create', - method='POST', - body=escape.json_encode({ - 'username': 'foobar', - 'password': 'password', - 'is_admin': False, - 'show_id': None - })) + response = self.fetch( + "/api/v1/auth/create", + method="POST", + body=escape.json_encode( + { + "username": "foobar", + "password": "password", + "is_admin": False, + "show_id": None, + } + ), + ) response_body = escape.json_decode(response.body) self.assertEqual(400, response.code) - self.assertTrue('message' in response_body) - self.assertEqual('Non admin user requires a show allocation', response_body['message']) + self.assertTrue("message" in response_body) + self.assertEqual( + "Non admin user requires a show allocation", response_body["message"] + ) def test_invalid_show(self): - response = self.fetch('/api/v1/auth/create', - method='POST', - body=escape.json_encode({ - 'username': 'foobar', - 'password': 'password', - 'is_admin': False, - 'show_id': 1 - })) + response = self.fetch( + "/api/v1/auth/create", + method="POST", + body=escape.json_encode( + { + "username": "foobar", + "password": "password", + "is_admin": False, + "show_id": 1, + } + ), + ) response_body = escape.json_decode(response.body) self.assertEqual(400, response.code) - self.assertTrue('message' in response_body) - self.assertEqual('Show not found', response_body['message']) + self.assertTrue("message" in response_body) + self.assertEqual("Show not found", response_body["message"]) def test_login_success(self): - self.fetch('/api/v1/auth/create', method='POST', - body=escape.json_encode({ - 'username': 'foobar', - 'password': 'password', - 'is_admin': True, - 'show_id': None - })) - - response = self.fetch('/api/v1/auth/login', method='POST', - body=escape.json_encode({ - 'username': 'foobar', - 'password': 'password' - })) + self.fetch( + "/api/v1/auth/create", + method="POST", + body=escape.json_encode( + { + "username": "foobar", + "password": "password", + "is_admin": True, + "show_id": None, + } + ), + ) + + response = self.fetch( + "/api/v1/auth/login", + method="POST", + body=escape.json_encode({"username": "foobar", "password": "password"}), + ) response_body = escape.json_decode(response.body) self.assertEqual(200, response.code) - self.assertTrue('message' in response_body) - self.assertEqual('Successful log in', response_body['message']) + self.assertTrue("message" in response_body) + self.assertEqual("Successful log in", response_body["message"]) def test_login_invalid_password(self): - self.fetch('/api/v1/auth/create', method='POST', - body=escape.json_encode({ - 'username': 'foobar', - 'password': 'password', - 'is_admin': True, - 'show_id': None - })) - - response = self.fetch('/api/v1/auth/login', method='POST', - body=escape.json_encode({ - 'username': 'foobar', - 'password': 'wrongpassword' - })) + self.fetch( + "/api/v1/auth/create", + method="POST", + body=escape.json_encode( + { + "username": "foobar", + "password": "password", + "is_admin": True, + "show_id": None, + } + ), + ) + + response = self.fetch( + "/api/v1/auth/login", + method="POST", + body=escape.json_encode( + {"username": "foobar", "password": "wrongpassword"} + ), + ) response_body = escape.json_decode(response.body) self.assertEqual(401, response.code) - self.assertTrue('message' in response_body) - self.assertEqual('Invalid username/password', response_body['message']) + self.assertTrue("message" in response_body) + self.assertEqual("Invalid username/password", response_body["message"]) def test_login_invalid_username(self): - self.fetch('/api/v1/auth/create', method='POST', - body=escape.json_encode({ - 'username': 'foobar', - 'password': 'password', - 'is_admin': True, - 'show_id': None - })) - - response = self.fetch('/api/v1/auth/login', method='POST', - body=escape.json_encode({ - 'username': 'wrongusername', - 'password': 'password' - })) + self.fetch( + "/api/v1/auth/create", + method="POST", + body=escape.json_encode( + { + "username": "foobar", + "password": "password", + "is_admin": True, + "show_id": None, + } + ), + ) + + response = self.fetch( + "/api/v1/auth/login", + method="POST", + body=escape.json_encode( + {"username": "wrongusername", "password": "password"} + ), + ) response_body = escape.json_decode(response.body) self.assertEqual(401, response.code) - self.assertTrue('message' in response_body) - self.assertEqual('Invalid username/password', response_body['message']) + self.assertTrue("message" in response_body) + self.assertEqual("Invalid username/password", response_body["message"]) diff --git a/server/test/test_digi_server.py b/server/test/test_digi_server.py index f2017954..5d5d1483 100644 --- a/server/test/test_digi_server.py +++ b/server/test/test_digi_server.py @@ -5,24 +5,24 @@ class TestDigiScriptServer(DigiScriptTestCase): def test_debug(self): - response = self.fetch('/debug') + response = self.fetch("/debug") response_body = tornado.escape.json_decode(response.body) self.assertEqual(200, response.code) - self.assertTrue('status' in response_body) - self.assertEqual('OK', response_body['status']) + self.assertTrue("status" in response_body) + self.assertEqual("OK", response_body["status"]) def test_api_debug(self): - response = self.fetch('/api/v1/debug') + response = self.fetch("/api/v1/debug") response_body = tornado.escape.json_decode(response.body) self.assertEqual(200, response.code) - self.assertTrue('status' in response_body) - self.assertEqual('OK', response_body['status']) - self.assertTrue('api_version' in response_body) - self.assertEqual(1, response_body['api_version']) + self.assertTrue("status" in response_body) + self.assertEqual("OK", response_body["status"]) + self.assertTrue("api_version" in response_body) + self.assertEqual(1, response_body["api_version"]) def test_debug_metrics(self): - response = self.fetch('/debug/metrics') + response = self.fetch("/debug/metrics") self.assertEqual(200, response.code) diff --git a/server/test/test_settings.py b/server/test/test_settings.py index 62f06935..a2f6441d 100644 --- a/server/test/test_settings.py +++ b/server/test/test_settings.py @@ -1,6 +1,7 @@ from tornado.testing import gen_test from digi_server.logger import get_logger + from .test_utils import DigiScriptTestCase @@ -8,30 +9,33 @@ class TestSettings(DigiScriptTestCase): @gen_test def test_set_invalid_name(self): - yield self._app.digi_settings.set('not_present_key', 'some_value') - self.assertLogs(get_logger(), 'Setting not_present_key found in settings file is not ' - 'defined, ignoring!') + yield self._app.digi_settings.set("not_present_key", "some_value") + self.assertLogs( + get_logger(), + "Setting not_present_key found in settings file is not " + "defined, ignoring!", + ) @gen_test def test_get_invalid_name(self): with self.assertRaises(KeyError): - yield self._app.digi_settings.get('not_present_key') + yield self._app.digi_settings.get("not_present_key") @gen_test def test_get_set_valid_name(self): - cur_val = yield self._app.digi_settings.get('debug_mode') + cur_val = yield self._app.digi_settings.get("debug_mode") self.assertEqual(False, cur_val) - yield self._app.digi_settings.set('debug_mode', True) - cur_val = yield self._app.digi_settings.get('debug_mode') + yield self._app.digi_settings.set("debug_mode", True) + cur_val = yield self._app.digi_settings.get("debug_mode") self.assertEqual(True, cur_val) @gen_test def test_invalid_type(self): with self.assertRaises(TypeError): - yield self._app.digi_settings.set('debug_mode', 'not_a_bool') + yield self._app.digi_settings.set("debug_mode", "not_a_bool") @gen_test def test_not_nullable(self): with self.assertRaises(RuntimeError): - yield self._app.digi_settings.set('debug_mode', None) + yield self._app.digi_settings.set("debug_mode", None) diff --git a/server/test/test_utils.py b/server/test/test_utils.py index a5479b38..f5ec36df 100644 --- a/server/test/test_utils.py +++ b/server/test/test_utils.py @@ -11,21 +11,23 @@ class DigiScriptTestCase(AsyncHTTPTestCase): def get_app(self): - return DigiScriptServer(debug=True, settings_path=self.settings_path, - skip_migrations=True, skip_migrations_check=True) + return DigiScriptServer( + debug=True, + settings_path=self.settings_path, + skip_migrations=True, + skip_migrations_check=True, + ) def setUp(self): - base_path = os.path.join(os.path.dirname(__file__), 'conf') - settings_path = os.path.join(base_path, 'digiscript.json') + base_path = os.path.join(os.path.dirname(__file__), "conf") + settings_path = os.path.join(base_path, "digiscript.json") self.settings_path = settings_path if not os.path.exists(os.path.dirname(self.settings_path)): os.makedirs(os.path.dirname(self.settings_path)) - with open(self.settings_path, 'w', encoding='UTF-8') as file_pointer: - json.dump({ - 'db_path': 'sqlite://' - }, file_pointer) + with open(self.settings_path, "w", encoding="UTF-8") as file_pointer: + json.dump({"db_path": "sqlite://"}, file_pointer) super().setUp() diff --git a/server/utils/database.py b/server/utils/database.py index e985a70a..c351e43c 100644 --- a/server/utils/database.py +++ b/server/utils/database.py @@ -1,15 +1,15 @@ import functools -from sqlalchemy import MetaData -from sqlalchemy.orm import sessionmaker, declarative_base -from tornado_sqlalchemy import SQLAlchemy, SessionEx, BindMeta +from sqlalchemy import MetaData +from sqlalchemy.orm import declarative_base, sessionmaker +from tornado_sqlalchemy import BindMeta, SessionEx, SQLAlchemy class DeleteMixin: - def pre_delete(self, session: 'DigiDBSession'): + def pre_delete(self, session: "DigiDBSession"): raise NotImplementedError - def post_delete(self, session: 'DigiDBSession'): + def post_delete(self, session: "DigiDBSession"): raise NotImplementedError @@ -30,9 +30,13 @@ def __init__(self, url=None, binds=None, session_options=None, engine_options=No self.sessionmaker = None super().__init__(url, binds, session_options, engine_options) - def configure(self, url=None, binds=None, session_options=None, engine_options=None): + def configure( + self, url=None, binds=None, session_options=None, engine_options=None + ): super().configure(url, binds, session_options, engine_options) - self.sessionmaker = sessionmaker(class_=DigiDBSession, db=self, **(session_options or {})) + self.sessionmaker = sessionmaker( + class_=DigiDBSession, db=self, **(session_options or {}) + ) @functools.lru_cache def get_mapper_for_table(self, tablename): @@ -43,11 +47,11 @@ def get_mapper_for_table(self, tablename): def make_declarative_base(self): convention = { - "ix": 'ix_%(column_0_label)s', + "ix": "ix_%(column_0_label)s", "uq": "uq_%(table_name)s_%(column_0_name)s", "ck": "ck_%(table_name)s_%(constraint_name)s", "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", - "pk": "pk_%(table_name)s" + "pk": "pk_%(table_name)s", } metadata = MetaData(naming_convention=convention) return declarative_base(metaclass=BindMeta, metadata=metadata) diff --git a/server/utils/file_watcher.py b/server/utils/file_watcher.py index ba1917e9..c942ed84 100644 --- a/server/utils/file_watcher.py +++ b/server/utils/file_watcher.py @@ -9,7 +9,7 @@ class FileWatcher: def __init__(self, file_path, callback, poll_interval=500): if not os.path.isfile(file_path): - raise RuntimeError(f'Path {file_path} does not exist') + raise RuntimeError(f"Path {file_path} does not exist") self._file_path = file_path self._poll_interval = poll_interval @@ -34,19 +34,22 @@ class IOLoopFileWatcher(FileWatcher): def __init__(self, file_path, callback, poll_interval=500): if not IOLoop.current(): - raise RuntimeError('No IOLoop found!') + raise RuntimeError("No IOLoop found!") super().__init__(file_path, callback, poll_interval) - self._task: PeriodicCallback = PeriodicCallback(self._poll_file, self._poll_interval, 0.05) + self._task: PeriodicCallback = PeriodicCallback( + self._poll_file, self._poll_interval, 0.05 + ) self._error_callback = None def _poll_file(self): if not (os.path.exists(self._file_path) and os.path.isfile(self._file_path)): if self._error_callback is None: - raise IOError(f'File {self._file_path} could not be found') + raise IOError(f"File {self._file_path} could not be found") - get_logger().warning(f'File {self._file_path} could not be found, calling error ' - f'callback') + get_logger().warning( + f"File {self._file_path} could not be found, calling error " f"callback" + ) self.stop() self._error_callback() @@ -57,7 +60,7 @@ def _poll_file(self): def add_error_callback(self, callback): if not callable(callback): - raise ValueError('`callback` is not callable') + raise ValueError("`callback` is not callable") self._error_callback = callback def watch(self): diff --git a/server/utils/pkg_utils.py b/server/utils/pkg_utils.py index 2c446334..c39937bc 100644 --- a/server/utils/pkg_utils.py +++ b/server/utils/pkg_utils.py @@ -1,5 +1,6 @@ -from pkgutil import iter_modules import sys +from pkgutil import iter_modules + from setuptools import find_packages @@ -7,15 +8,15 @@ def find_modules(path, prefix=None): modules = set() for pkg in find_packages(path): modules.add(pkg) - pkgpath = path + '/' + pkg.replace('.', '/') + pkgpath = path + "/" + pkg.replace(".", "/") if sys.version_info.major == 3 and sys.version_info.minor < 6: for _, name, ispkg in iter_modules([pkgpath]): if not ispkg: - modules.add(pkg + '.' + name) + modules.add(pkg + "." + name) else: for info in iter_modules([pkgpath]): if not info.ispkg: - modules.add(pkg + '.' + info.name) + modules.add(pkg + "." + info.name) if not prefix: return modules diff --git a/server/utils/singleton.py b/server/utils/singleton.py index 15c8d8cc..a4f95811 100644 --- a/server/utils/singleton.py +++ b/server/utils/singleton.py @@ -30,7 +30,7 @@ def instance(self, **kwargs): return self._instance def __call__(self): - raise TypeError('Singletons must be accessed through `instance()`.') + raise TypeError("Singletons must be accessed through `instance()`.") def __instancecheck__(self, inst): return isinstance(inst, self._decorated) diff --git a/server/utils/tree.py b/server/utils/tree.py index 3acd90bf..21ed43a9 100644 --- a/server/utils/tree.py +++ b/server/utils/tree.py @@ -3,11 +3,13 @@ from anytree import Node -def flatten(node: Node, attr: str = 'name') -> List[List]: +def flatten(node: Node, attr: str = "name") -> List[List]: result = [] for child in node.children: if child.is_leaf: result.append([getattr(node, attr), getattr(child, attr)]) else: - result += [[getattr(node, attr)] + subpath for subpath in flatten(child, attr)] + result += [ + [getattr(node, attr)] + subpath for subpath in flatten(child, attr) + ] return result diff --git a/server/utils/web/base_controller.py b/server/utils/web/base_controller.py index aa74ab64..3a134174 100644 --- a/server/utils/web/base_controller.py +++ b/server/utils/web/base_controller.py @@ -1,17 +1,17 @@ from __future__ import annotations -from typing import Optional, Awaitable, Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Awaitable, Optional -from tornado import httputil, escape -from tornado.web import RequestHandler, HTTPError +from tornado import escape, httputil +from tornado.web import HTTPError, RequestHandler from tornado_sqlalchemy import SessionMixin +from digi_server.logger import get_logger from models.models import db from models.show import Show from models.user import User from rbac.role import Role from schemas.schemas import ShowSchema, UserSchema -from digi_server.logger import get_logger if TYPE_CHECKING: from digi_server.app_server import DigiScriptServer @@ -19,27 +19,31 @@ class BaseController(SessionMixin, RequestHandler): - def __init__(self, - application: DigiScriptServer, - request: httputil.HTTPServerRequest, - **kwargs: Any) -> None: + def __init__( + self, + application: DigiScriptServer, + request: httputil.HTTPServerRequest, + **kwargs: Any, + ) -> None: super().__init__(application, request, **kwargs) self.application: DigiScriptServer = self.application self.current_show: Optional[dict] = None - async def prepare(self) -> Optional[Awaitable[None]]: # pylint: disable=invalid-overridden-method + async def prepare( + self, + ) -> Optional[Awaitable[None]]: # pylint: disable=invalid-overridden-method show_schema = ShowSchema() user_schema = UserSchema() with self.make_session() as session: - user_id = self.get_secure_cookie('digiscript_user_id') + user_id = self.get_secure_cookie("digiscript_user_id") if user_id: user = session.query(User).get(int(user_id)) if user: self.current_user = user_schema.dump(user) else: - self.clear_cookie('digiscript_user_id') + self.clear_cookie("digiscript_user_id") - current_show = await self.application.digi_settings.get('current_show') + current_show = await self.application.digi_settings.get("current_show") if current_show: show = session.query(Show).get(current_show) if show: @@ -48,38 +52,41 @@ async def prepare(self) -> Optional[Awaitable[None]]: # pylint: disable=invalid def requires_role(self, resource: db.Model, role: Role): if not self.current_user: - raise HTTPError(401, log_message='Not logged in') - if self.current_user['is_admin']: + raise HTTPError(401, log_message="Not logged in") + if self.current_user["is_admin"]: return with self.make_session() as session: - user = session.query(User).get(self.current_user['id']) + user = session.query(User).get(self.current_user["id"]) if not user: raise HTTPError(500) if not self.application.rbac.has_role(user, resource, role): - raise HTTPError(403, log_message='Not authorised') + raise HTTPError(403, log_message="Not authorised") def get_current_show(self) -> Optional[dict]: return self.current_show def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: - raise RuntimeError( - f'Data streaming not supported for {self.__class__}') + raise RuntimeError(f"Data streaming not supported for {self.__class__}") class BaseAPIController(BaseController): def _unimplemented_method(self, *args: str, **kwargs: str) -> None: self.set_status(405) - self.write({'message': '405 not allowed'}) + self.write({"message": "405 not allowed"}) def on_finish(self): if self.request.body: try: - get_logger().debug(f'{self.request.method} ' - f'{self.request.path} ' - f'{escape.json_decode(self.request.body)}') + get_logger().debug( + f"{self.request.method} " + f"{self.request.path} " + f"{escape.json_decode(self.request.body)}" + ) except BaseException: - get_logger().debug(f'{self.request.method} ' - f'{self.request.path} ' - f'{self.request.body}') + get_logger().debug( + f"{self.request.method} " + f"{self.request.path} " + f"{self.request.body}" + ) super().on_finish() diff --git a/server/utils/web/route.py b/server/utils/web/route.py index 6f154fbe..607d9924 100644 --- a/server/utils/web/route.py +++ b/server/utils/web/route.py @@ -1,9 +1,9 @@ -from enum import Enum import urllib.parse +from enum import Enum from functools import lru_cache -from tornado.web import URLSpec import tornado.escape +from tornado.web import URLSpec from tornado.websocket import WebSocketHandler from utils.web.base_controller import BaseAPIController @@ -41,12 +41,12 @@ def ignored_logging_routes(cls): @staticmethod def _url_escape(url): - return urllib.parse.quote(tornado.escape.utf8(url), '') + return urllib.parse.quote(tornado.escape.utf8(url), "") @classmethod def make(cls, _name, **kwargs): if _name not in cls._formats: - raise KeyError(f'No route by the name of `{_name}`') + raise KeyError(f"No route by the name of `{_name}`") kwargs = {k: cls._url_escape(v) for k, v in kwargs.items()} return cls._formats[_name] % kwargs @@ -63,6 +63,7 @@ def __init__(self, route: str, api_version: ApiVersion, name: str = None): def __call__(self, controller): if not issubclass(controller, (BaseAPIController, WebSocketHandler)): raise RuntimeError( - f'Controller class {controller.__name__} is not an ' - f'instance of BaseAPIController or WebSocketHandler') + f"Controller class {controller.__name__} is not an " + f"instance of BaseAPIController or WebSocketHandler" + ) super().__call__(controller) diff --git a/server/utils/web/web_decorators.py b/server/utils/web/web_decorators.py index d62d94f5..1e8d0144 100644 --- a/server/utils/web/web_decorators.py +++ b/server/utils/web/web_decorators.py @@ -1,5 +1,5 @@ import functools -from typing import Callable, Optional, Awaitable +from typing import Awaitable, Callable, Optional from tornado.web import HTTPError @@ -7,12 +7,12 @@ def requires_show( - method: Callable[..., Optional[Awaitable[None]]] + method: Callable[..., Optional[Awaitable[None]]], ) -> Callable[..., Optional[Awaitable[None]]]: @functools.wraps(method) def wrapper(self: BaseController, *args, **kwargs): if not self.get_current_show(): - raise HTTPError(400, log_message='No show loaded') + raise HTTPError(400, log_message="No show loaded") return method(self, *args, **kwargs) @@ -20,24 +20,25 @@ def wrapper(self: BaseController, *args, **kwargs): def require_admin( - method: Callable[..., Optional[Awaitable[None]]] + method: Callable[..., Optional[Awaitable[None]]], ) -> Callable[..., Optional[Awaitable[None]]]: @functools.wraps(method) def wrapper(self: BaseController, *args, **kwargs): - if not self.current_user or not self.current_user['is_admin']: - raise HTTPError(401, log_message='Not admin user') + if not self.current_user or not self.current_user["is_admin"]: + raise HTTPError(401, log_message="Not admin user") return method(self, *args, **kwargs) + return wrapper def no_live_session( - method: Callable[..., Optional[Awaitable[None]]] + method: Callable[..., Optional[Awaitable[None]]], ) -> Callable[..., Optional[Awaitable[None]]]: @functools.wraps(method) def wrapper(self: BaseController, *args, **kwargs): current_show = self.get_current_show() - if current_show and current_show['current_session_id']: - raise HTTPError(409, log_message='Current session in progress') + if current_show and current_show["current_session_id"]: + raise HTTPError(409, log_message="Current session in progress") return method(self, *args, **kwargs) From e413c7603a7895d6838d64d0cadfc34e4a4fc302 Mon Sep 17 00:00:00 2001 From: Tim Bradgate Date: Wed, 16 Apr 2025 18:41:23 +0100 Subject: [PATCH 03/11] Configure stage direction style from editor --- .../show/config/script/ScriptEditor.vue | 8 +++++-- .../show/config/script/ScriptLineEditor.vue | 23 +++++++++++++++++++ server/controllers/api/show/script.py | 2 ++ server/models/script.py | 1 - 4 files changed, 31 insertions(+), 3 deletions(-) diff --git a/client/src/vue_components/show/config/script/ScriptEditor.vue b/client/src/vue_components/show/config/script/ScriptEditor.vue index e8c59ced..9c8cace2 100644 --- a/client/src/vue_components/show/config/script/ScriptEditor.vue +++ b/client/src/vue_components/show/config/script/ScriptEditor.vue @@ -102,6 +102,7 @@ :previous-line-fn="getPreviousLineForIndex" :next-line-fn="getNextLineForIndex" :is-stage-direction="line.stage_direction" + :stage-direction-styles="STAGE_DIRECTION_STYLES" @input="lineChange(line, index)" @doneEditing="doneEditingLine(currentEditPage, index)" @deleteLine="deleteLine(currentEditPage, index)" @@ -300,6 +301,7 @@ export default { page: null, stage_direction: false, line_parts: [], + stage_direction_style_id: null, }, curSavePage: null, totalSavePages: null, @@ -352,6 +354,8 @@ export default { await this.GET_SCENE_LIST(); await this.GET_CHARACTER_LIST(); await this.GET_CHARACTER_GROUP_LIST(); + // Stage direction styles + await this.GET_STAGE_DIRECTION_STYLES(); // Handle script cuts await this.GET_CUTS(); this.resetCutsToSaved(); @@ -785,7 +789,7 @@ export default { 'SET_CUT_MODE', 'INSERT_BLANK_LINE', 'RESET_INSERTED']), ...mapActions(['GET_SCENE_LIST', 'GET_ACT_LIST', 'GET_CHARACTER_LIST', 'GET_CHARACTER_GROUP_LIST', 'LOAD_SCRIPT_PAGE', 'ADD_BLANK_PAGE', 'GET_SCRIPT_CONFIG_STATUS', - 'RESET_TO_SAVED', 'SAVE_NEW_PAGE', 'SAVE_CHANGED_PAGE', 'GET_CUTS', 'SAVE_SCRIPT_CUTS']), + 'RESET_TO_SAVED', 'SAVE_NEW_PAGE', 'SAVE_CHANGED_PAGE', 'GET_CUTS', 'SAVE_SCRIPT_CUTS', 'GET_STAGE_DIRECTION_STYLES']), }, computed: { canGenerateDebugScript() { @@ -826,7 +830,7 @@ export default { ...mapGetters(['CURRENT_SHOW', 'TMP_SCRIPT', 'ACT_LIST', 'SCENE_LIST', 'CHARACTER_LIST', 'CHARACTER_GROUP_LIST', 'CAN_REQUEST_EDIT', 'CURRENT_EDITOR', 'INTERNAL_UUID', 'GET_SCRIPT_PAGE', 'DEBUG_MODE_ENABLED', 'DELETED_LINES', 'SCENE_BY_ID', 'ACT_BY_ID', - 'IS_CUT_MODE', 'SCRIPT_CUTS', 'INSERTED_LINES']), + 'IS_CUT_MODE', 'SCRIPT_CUTS', 'INSERTED_LINES', 'STAGE_DIRECTION_STYLES']), }, watch: { currentEditPage(val) { diff --git a/client/src/vue_components/show/config/script/ScriptLineEditor.vue b/client/src/vue_components/show/config/script/ScriptLineEditor.vue index 76e81b60..5fb02b49 100644 --- a/client/src/vue_components/show/config/script/ScriptLineEditor.vue +++ b/client/src/vue_components/show/config/script/ScriptLineEditor.vue @@ -72,6 +72,19 @@ @input="stateChange" @addLinePart="addLinePart" /> + + + ({ value: style.id, text: style.description })), + ]; + }, }, }; diff --git a/server/controllers/api/show/script.py b/server/controllers/api/show/script.py index 931cd335..594dc655 100644 --- a/server/controllers/api/show/script.py +++ b/server/controllers/api/show/script.py @@ -510,6 +510,7 @@ async def post(self): scene_id=line["scene_id"], page=line["page"], stage_direction=line["stage_direction"], + stage_direction_style_id=line["stage_direction_style_id"], ) session.add(line_obj) session.flush() @@ -629,6 +630,7 @@ def _create_new_line(session, revision, line, previous_line, with_association=Tr scene_id=line["scene_id"], page=line["page"], stage_direction=line["stage_direction"], + stage_direction_style_id=line["stage_direction_style_id"], ) session.add(line_obj) session.flush() diff --git a/server/models/script.py b/server/models/script.py index 461d6a7c..39d33324 100644 --- a/server/models/script.py +++ b/server/models/script.py @@ -53,7 +53,6 @@ class ScriptLine(db.Model): act = relationship("Act", uselist=False, back_populates="lines") scene = relationship("Scene", uselist=False, back_populates="lines") - stage_direction_style = relationship("StageDirectionStyle", uselist=False) class ScriptLineRevisionAssociation(db.Model, DeleteMixin): From 933e0481c03b561bbc935bc63c83711a83e02214 Mon Sep 17 00:00:00 2001 From: Tim Bradgate Date: Wed, 16 Apr 2025 18:41:51 +0100 Subject: [PATCH 04/11] Fix database upgrades and foreign key deletes --- server/digi_server/app_server.py | 65 +++++++++++++++++++++++--------- server/utils/database.py | 23 +++++++++++ server/utils/exceptions.py | 4 ++ 3 files changed, 74 insertions(+), 18 deletions(-) diff --git a/server/digi_server/app_server.py b/server/digi_server/app_server.py index ef811a69..9e045c8f 100644 --- a/server/digi_server/app_server.py +++ b/server/digi_server/app_server.py @@ -7,6 +7,7 @@ from alembic import command, script from alembic.config import Config from alembic.runtime import migration +from sqlalchemy import Column, String, event from tornado.ioloop import IOLoop from tornado.web import Application, StaticFileHandler from tornado_prometheus import PrometheusMixIn @@ -24,7 +25,7 @@ from rbac.rbac import RBACController from utils.database import DigiSQLAlchemy from utils.env_parser import EnvParser -from utils.exceptions import DatabaseUpgradeRequired +from utils.exceptions import DatabaseTypeException, DatabaseUpgradeRequired from utils.web.route import Route @@ -51,25 +52,53 @@ def __init__( self.clients: List[WebSocketController] = [] self._db: DigiSQLAlchemy = models.db - # Perform database migrations - if not skip_migrations: - self._run_migrations() + + db_path: str = self.digi_settings.settings.get("db_path").get_value() + if db_path.startswith("sqlite:///"): + db_file_path = db_path.replace("sqlite:///", "") + else: + raise DatabaseTypeException("Only SQLite is supported") + # Database set up, if it doesn't exist then create it + # Otherwise attempt to upgrade it + if not os.path.exists(db_file_path): + get_logger().info("Database file does not exist, creating it") + get_logger().info(f"Using {db_path} as DB path") + + class AlembicVersion(self._db.Model): + __tablename__ = "alembic_version" + + version_num = Column(String(32), primary_key=True) + + self._db.configure(url=db_path) + self.rbac = RBACController(self) + self._configure_rbac() + self._db.create_all() + + script_ = script.ScriptDirectory.from_config(self._alembic_config) + current_migration_head = script_.get_current_head() + get_logger().info( + f"Setting current migration head revision: {current_migration_head}" + ) + with self._db.sessionmaker() as session: + session.add(AlembicVersion(version_num=current_migration_head)) + session.commit() else: - get_logger().warning("Skipping performing database migrations") - # And then check the database is up-to-date - if not skip_migrations_check: - self._check_migrations() + # Perform database migrations + if not skip_migrations: + self._run_migrations() else: - get_logger().warning("Skipping database migrations check") - # Finally, configure the database - db_path = self.digi_settings.settings.get("db_path").get_value() - get_logger().info(f"Using {db_path} as DB path") - self._db.configure(url=db_path) - - self.rbac = RBACController(self) - self._configure_rbac() - - self._db.create_all() + get_logger().warning("Skipping performing database migrations") + # And then check the database is up-to-date + if not skip_migrations_check: + self._check_migrations() + else: + get_logger().warning("Skipping database migrations check") + # Finally, configure the database + get_logger().info(f"Using {db_path} as DB path") + self._db.configure(url=db_path) + self.rbac = RBACController(self) + self._configure_rbac() + self._db.create_all() # Clear out all sessions since we are starting the app up with self._db.sessionmaker() as session: diff --git a/server/utils/database.py b/server/utils/database.py index c351e43c..edb57e31 100644 --- a/server/utils/database.py +++ b/server/utils/database.py @@ -28,6 +28,29 @@ class DigiSQLAlchemy(SQLAlchemy): def __init__(self, url=None, binds=None, session_options=None, engine_options=None): self.sessionmaker = None + # For SQLite connections, set up an event listener to enable foreign keys + from sqlalchemy import event + + # Store the original create_engine method + original_create_engine = self.create_engine + + # Override create_engine to add event listener for SQLite + def create_engine_with_fk_support(*args, **kwargs): + engine = original_create_engine(*args, **kwargs) + + # Only add the event listener if it's a SQLite database + if "sqlite" in str(engine.url): + + @event.listens_for(engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + return engine + + # Replace the create_engine method + self.create_engine = create_engine_with_fk_support super().__init__(url, binds, session_options, engine_options) def configure( diff --git a/server/utils/exceptions.py b/server/utils/exceptions.py index 3c8af989..2a42d37f 100644 --- a/server/utils/exceptions.py +++ b/server/utils/exceptions.py @@ -1,2 +1,6 @@ class DatabaseUpgradeRequired(Exception): pass + + +class DatabaseTypeException(Exception): + pass From a19c92a374026715ca4a77d36e244efbc245202d Mon Sep 17 00:00:00 2001 From: Tim Bradgate Date: Wed, 16 Apr 2025 18:46:03 +0100 Subject: [PATCH 05/11] Add black and isort checking to github actions --- .github/workflows/pylint.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 458985e2..aede243d 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -21,6 +21,13 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt + pip install black isort - name: Analysing the code with pylint run: | pylint-ignore $(git ls-files '*.py') + - name: Check code formatting with black + run: | + black --check $(git ls-files '*.py') + - name: Check import sorting with isort + run: | + isort --check $(git ls-files '*.py') --profile=black From 1732c0f46b55cbd6c04c73331176bcd81db33b17 Mon Sep 17 00:00:00 2001 From: Tim Bradgate Date: Wed, 16 Apr 2025 18:48:39 +0100 Subject: [PATCH 06/11] Improve github actions script --- .github/workflows/pylint.yml | 82 +++++++++++++++++++++++++++--------- 1 file changed, 61 insertions(+), 21 deletions(-) diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index aede243d..1fbf0e36 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -1,9 +1,10 @@ -name: Pylint +name: Python Linting and Formatting on: [push] jobs: - build: + pylint: + name: Pylint runs-on: ubuntu-latest defaults: run: @@ -12,22 +13,61 @@ jobs: matrix: python-version: ["3.10"] steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements.txt - pip install black isort - - name: Analysing the code with pylint - run: | - pylint-ignore $(git ls-files '*.py') - - name: Check code formatting with black - run: | - black --check $(git ls-files '*.py') - - name: Check import sorting with isort - run: | - isort --check $(git ls-files '*.py') --profile=black + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + - name: Analysing the code with pylint + run: | + pylint-ignore $(git ls-files '*.py') + + black: + name: Black + runs-on: ubuntu-latest + defaults: + run: + working-directory: ./server + strategy: + matrix: + python-version: ["3.10"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install black + - name: Check code formatting with black + run: | + black --check $(git ls-files '*.py') + + isort: + name: isort + runs-on: ubuntu-latest + defaults: + run: + working-directory: ./server + strategy: + matrix: + python-version: ["3.10"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install isort + - name: Check import sorting with isort + run: | + isort --check $(git ls-files '*.py') --profile=black \ No newline at end of file From 5c875aef5830487ad2302c4bfbde339471c66fe0 Mon Sep 17 00:00:00 2001 From: Tim Bradgate Date: Wed, 16 Apr 2025 18:52:03 +0100 Subject: [PATCH 07/11] Fix database type check code --- server/digi_server/app_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/digi_server/app_server.py b/server/digi_server/app_server.py index 9e045c8f..1ad6ea32 100644 --- a/server/digi_server/app_server.py +++ b/server/digi_server/app_server.py @@ -54,7 +54,7 @@ def __init__( self._db: DigiSQLAlchemy = models.db db_path: str = self.digi_settings.settings.get("db_path").get_value() - if db_path.startswith("sqlite:///"): + if db_path.startswith("sqlite://"): db_file_path = db_path.replace("sqlite:///", "") else: raise DatabaseTypeException("Only SQLite is supported") From d8e325e592430ec34d58593014985b79957a9d46 Mon Sep 17 00:00:00 2001 From: Tim Bradgate Date: Wed, 16 Apr 2025 18:57:39 +0100 Subject: [PATCH 08/11] Fix database type check code --- server/digi_server/app_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/digi_server/app_server.py b/server/digi_server/app_server.py index 1ad6ea32..a8c9fbb3 100644 --- a/server/digi_server/app_server.py +++ b/server/digi_server/app_server.py @@ -60,7 +60,7 @@ def __init__( raise DatabaseTypeException("Only SQLite is supported") # Database set up, if it doesn't exist then create it # Otherwise attempt to upgrade it - if not os.path.exists(db_file_path): + if db_path.startswith("sqlite:///") and not os.path.exists(db_file_path): get_logger().info("Database file does not exist, creating it") get_logger().info(f"Using {db_path} as DB path") From 03719657361a805073b61be632c749dab6114a8f Mon Sep 17 00:00:00 2001 From: Tim Bradgate Date: Thu, 17 Apr 2025 00:28:40 +0100 Subject: [PATCH 09/11] Add stage direction styles to front end views --- client/src/views/show/ShowLiveView.vue | 6 ++- .../show/config/cues/CueEditor.vue | 7 ++- .../show/config/cues/ScriptLineCueEditor.vue | 47 +++++++++++++++++- .../show/config/script/ScriptEditor.vue | 1 + .../show/config/script/ScriptLineViewer.vue | 47 +++++++++++++++++- .../config/script/StageDirectionStyles.vue | 12 ++--- .../show/live/ScriptLineViewer.vue | 49 ++++++++++++++++++- 7 files changed, 153 insertions(+), 16 deletions(-) diff --git a/client/src/views/show/ShowLiveView.vue b/client/src/views/show/ShowLiveView.vue index 370df0bc..eb1a9f0b 100644 --- a/client/src/views/show/ShowLiveView.vue +++ b/client/src/views/show/ShowLiveView.vue @@ -73,6 +73,7 @@ :cue-types="CUE_TYPES" :cues="getCuesForLine(line)" :cuts="SCRIPT_CUTS" + :stage-direction-styles="STAGE_DIRECTION_STYLES" @last-line-change="handleLastLineChange" @first-line-change="handleFirstLineChange" /> @@ -149,6 +150,7 @@ export default { await this.GET_CUE_TYPES(); await this.LOAD_CUES(); await this.GET_CUTS(); + await this.GET_STAGE_DIRECTION_STYLES(); await this.getMaxScriptPage(); this.updateElapsedTime(); @@ -415,7 +417,7 @@ export default { }, ...mapActions(['GET_SHOW_SESSION_DATA', 'LOAD_SCRIPT_PAGE', 'GET_ACT_LIST', 'GET_SCENE_LIST', 'GET_CHARACTER_LIST', 'GET_CHARACTER_GROUP_LIST', 'LOAD_CUES', 'GET_CUE_TYPES', - 'GET_CUTS']), + 'GET_CUTS', 'GET_STAGE_DIRECTION_STYLES']), }, computed: { pageIter() { @@ -440,7 +442,7 @@ export default { }, ...mapGetters(['CURRENT_SHOW_SESSION', 'GET_SCRIPT_PAGE', 'ACT_LIST', 'SCENE_LIST', 'CHARACTER_LIST', 'CHARACTER_GROUP_LIST', 'CURRENT_SHOW', 'CUE_TYPES', 'SCRIPT_CUES', - 'INTERNAL_UUID', 'SESSION_FOLLOW_DATA', 'SCRIPT_CUTS', 'SETTINGS']), + 'INTERNAL_UUID', 'SESSION_FOLLOW_DATA', 'SCRIPT_CUTS', 'SETTINGS', 'STAGE_DIRECTION_STYLES']), }, watch: { SESSION_FOLLOW_DATA() { diff --git a/client/src/vue_components/show/config/cues/CueEditor.vue b/client/src/vue_components/show/config/cues/CueEditor.vue index d55560f1..0fd49241 100644 --- a/client/src/vue_components/show/config/cues/CueEditor.vue +++ b/client/src/vue_components/show/config/cues/CueEditor.vue @@ -62,6 +62,7 @@ :cue-types="CUE_TYPES" :cues="getCuesForLine(line)" :line-part-cuts="SCRIPT_CUTS" + :stage-direction-styles="STAGE_DIRECTION_STYLES" /> @@ -185,6 +186,7 @@ export default { await this.GET_CUE_TYPES(); await this.LOAD_CUES(); await this.GET_CUTS(); + await this.GET_STAGE_DIRECTION_STYLES(); // Get the max page of the saved version of the script await this.getMaxScriptPage(); @@ -261,7 +263,7 @@ export default { ...mapActions(['GET_SCENE_LIST', 'GET_ACT_LIST', 'GET_CHARACTER_LIST', 'GET_CHARACTER_GROUP_LIST', 'LOAD_SCRIPT_PAGE', 'ADD_BLANK_PAGE', 'GET_SCRIPT_CONFIG_STATUS', 'RESET_TO_SAVED', 'SAVE_NEW_PAGE', 'SAVE_CHANGED_PAGE', 'GET_CUE_TYPES', 'LOAD_CUES', - 'GET_CUTS']), + 'GET_CUTS', 'GET_STAGE_DIRECTION_STYLES']), }, computed: { currentEditPageKey() { @@ -275,7 +277,8 @@ export default { }, ...mapGetters(['CURRENT_SHOW', 'ACT_LIST', 'SCENE_LIST', 'CHARACTER_LIST', 'CHARACTER_GROUP_LIST', 'CAN_REQUEST_EDIT', 'CURRENT_EDITOR', 'INTERNAL_UUID', - 'GET_SCRIPT_PAGE', 'DEBUG_MODE_ENABLED', 'CUE_TYPES', 'SCRIPT_CUES', 'SCRIPT_CUTS']), + 'GET_SCRIPT_PAGE', 'DEBUG_MODE_ENABLED', 'CUE_TYPES', 'SCRIPT_CUES', 'SCRIPT_CUTS', + 'STAGE_DIRECTION_STYLES']), }, watch: { currentEditPage(val) { diff --git a/client/src/vue_components/show/config/cues/ScriptLineCueEditor.vue b/client/src/vue_components/show/config/cues/ScriptLineCueEditor.vue index 439e5550..aa45e420 100644 --- a/client/src/vue_components/show/config/cues/ScriptLineCueEditor.vue +++ b/client/src/vue_components/show/config/cues/ScriptLineCueEditor.vue @@ -44,9 +44,21 @@ > - {{ line.line_parts[0].line_text }} + + + @@ -234,6 +246,10 @@ export default { required: true, type: Array, }, + stageDirectionStyles: { + required: true, + type: Array, + }, }, data() { return { @@ -411,6 +427,33 @@ export default { sceneLabel() { return this.scenes.find((scene) => (scene.id === this.line.scene_id)).name; }, + stageDirectionStyle() { + const sdStyle = this.stageDirectionStyles.find( + (style) => (style.id === this.line.stage_direction_style_id), + ); + if (this.line.stage_direction) { + return sdStyle; + } + return null; + }, + stageDirectionStyling() { + if (this.line.stage_direction_style_id == null || this.stageDirectionStyle == null) { + return { + 'background-color': 'darkslateblue', + 'font-style': 'italic', + }; + } + const style = { + 'font-weight': this.stageDirectionStyle.bold ? 'bold' : 'normal', + 'font-style': this.stageDirectionStyle.italic ? 'italic' : 'normal', + 'text-decoration-line': this.stageDirectionStyle.underline ? 'underline' : 'none', + color: this.stageDirectionStyle.text_colour, + }; + if (this.stageDirectionStyle.enable_background_colour) { + style['background-color'] = this.stageDirectionStyle.background_colour; + } + return style; + }, }, }; diff --git a/client/src/vue_components/show/config/script/ScriptEditor.vue b/client/src/vue_components/show/config/script/ScriptEditor.vue index 9c8cace2..8cf433d0 100644 --- a/client/src/vue_components/show/config/script/ScriptEditor.vue +++ b/client/src/vue_components/show/config/script/ScriptEditor.vue @@ -121,6 +121,7 @@ :can-edit="canEdit" :line-part-cuts="linePartCuts" :insert-mode="insertMode" + :stage-direction-styles="STAGE_DIRECTION_STYLES" @editLine="beginEditingLine(currentEditPage, index)" @cutLinePart="cutLinePart" @insertLine="insertLineAt(currentEditPage, index)" diff --git a/client/src/vue_components/show/config/script/ScriptLineViewer.vue b/client/src/vue_components/show/config/script/ScriptLineViewer.vue index d8a71976..d5f2c6a1 100644 --- a/client/src/vue_components/show/config/script/ScriptLineViewer.vue +++ b/client/src/vue_components/show/config/script/ScriptLineViewer.vue @@ -73,9 +73,21 @@ > - {{ line.line_parts[0].line_text }} + + + @@ -146,6 +158,10 @@ export default { type: Boolean, default: false, }, + stageDirectionStyles: { + required: true, + type: Array, + }, }, computed: { needsHeadings() { @@ -196,6 +212,33 @@ export default { sceneLabel() { return this.scenes.find((scene) => (scene.id === this.line.scene_id)).name; }, + stageDirectionStyle() { + const sdStyle = this.stageDirectionStyles.find( + (style) => (style.id === this.line.stage_direction_style_id), + ); + if (this.line.stage_direction) { + return sdStyle; + } + return null; + }, + stageDirectionStyling() { + if (this.line.stage_direction_style_id == null || this.stageDirectionStyle == null) { + return { + 'background-color': 'darkslateblue', + 'font-style': 'italic', + }; + } + const style = { + 'font-weight': this.stageDirectionStyle.bold ? 'bold' : 'normal', + 'font-style': this.stageDirectionStyle.italic ? 'italic' : 'normal', + 'text-decoration-line': this.stageDirectionStyle.underline ? 'underline' : 'none', + color: this.stageDirectionStyle.text_colour, + }; + if (this.stageDirectionStyle.enable_background_colour) { + style['background-color'] = this.stageDirectionStyle.background_colour; + } + return style; + }, ...mapGetters(['IS_CUT_MODE']), }, methods: { diff --git a/client/src/vue_components/show/config/script/StageDirectionStyles.vue b/client/src/vue_components/show/config/script/StageDirectionStyles.vue index ac9b425f..8e9f7d53 100644 --- a/client/src/vue_components/show/config/script/StageDirectionStyles.vue +++ b/client/src/vue_components/show/config/script/StageDirectionStyles.vue @@ -25,7 +25,7 @@ >

Example Stage Direction

-

@@ -38,7 +38,7 @@ -

+

Configuration Options

@@ -161,7 +161,7 @@ >

Example Stage Direction

-

@@ -174,7 +174,7 @@ -

+

Configuration Options

@@ -289,7 +289,7 @@