@@ -4,6 +4,7 @@ use std::{
44 path:: PathBuf ,
55} ;
66
7+ use crate :: common:: resolve:: AuthUser ;
78use crate :: {
89 common:: resolve:: CurrentUser ,
910 log:: { auth_info, auth_warn} ,
@@ -179,7 +180,7 @@ impl SessionRecordFile {
179180 /// that record time to the current time. This will not create a new record
180181 /// when one is not found. A record will only be updated if it is still
181182 /// valid at this time.
182- pub fn touch ( & mut self , scope : RecordScope , auth_user : UserId ) -> io:: Result < TouchResult > {
183+ pub fn touch ( & mut self , scope : RecordScope , auth_user : & AuthUser ) -> io:: Result < TouchResult > {
183184 // lock the file to indicate that we are currently in a writing operation
184185 let lock = FileLock :: exclusive ( & self . file , false ) ?;
185186 self . seek_to_first_record ( ) ?;
@@ -215,17 +216,12 @@ impl SessionRecordFile {
215216 Ok ( TouchResult :: NotFound )
216217 }
217218
218- /// Disable all records that match the given scope. If an auth user id is
219- /// given then only records with the given scope that are targeting that
220- /// specific user will be disabled.
221- pub fn disable ( & mut self , scope : RecordScope , auth_user : Option < UserId > ) -> io:: Result < ( ) > {
219+ /// Disable all records that match the given scope.
220+ pub fn disable ( & mut self , scope : RecordScope ) -> io:: Result < ( ) > {
222221 let lock = FileLock :: exclusive ( & self . file , false ) ?;
223222 self . seek_to_first_record ( ) ?;
224223 while let Some ( record) = self . next_record ( ) ? {
225- let must_disable = auth_user
226- . map ( |tu| record. matches ( & scope, tu) )
227- . unwrap_or_else ( || record. scope == scope) ;
228- if must_disable {
224+ if record. scope == scope {
229225 self . file . seek ( io:: SeekFrom :: Current ( -SIZE_OF_BOOL ) ) ?;
230226 write_bool ( false , & mut self . file ) ?;
231227 }
@@ -237,7 +233,7 @@ impl SessionRecordFile {
237233 /// Create a new record for the given scope and auth user id.
238234 /// If there is an existing record that matches the scope and auth user,
239235 /// then that record will be updated.
240- pub fn create ( & mut self , scope : RecordScope , auth_user : UserId ) -> io:: Result < CreateResult > {
236+ pub fn create ( & mut self , scope : RecordScope , auth_user : & AuthUser ) -> io:: Result < CreateResult > {
241237 // lock the file to indicate that we are currently writing to it
242238 let lock = FileLock :: exclusive ( & self . file , false ) ?;
243239 self . seek_to_first_record ( ) ?;
@@ -256,7 +252,7 @@ impl SessionRecordFile {
256252 }
257253
258254 // record was not found in the list so far, create a new one
259- let record = SessionRecord :: new ( scope, auth_user) ?;
255+ let record = SessionRecord :: new ( scope, auth_user. uid ) ?;
260256
261257 // make sure we really are at the end of the file
262258 self . file . seek ( io:: SeekFrom :: End ( 0 ) ) ?;
@@ -552,8 +548,8 @@ impl SessionRecord {
552548
553549 /// Returns true if this record matches the specified scope and is for the
554550 /// specified target auth user.
555- pub fn matches ( & self , scope : & RecordScope , auth_user : UserId ) -> bool {
556- self . scope == * scope && self . auth_user == auth_user
551+ pub fn matches ( & self , scope : & RecordScope , auth_user : & AuthUser ) -> bool {
552+ self . scope == * scope && self . auth_user == auth_user. uid
557553 }
558554
559555 /// Returns true if this record was written somewhere in the time range
@@ -566,11 +562,27 @@ impl SessionRecord {
566562
567563#[ cfg( test) ]
568564mod tests {
565+ use std:: path:: Path ;
566+
569567 use super :: * ;
568+ use crate :: common:: { SudoPath , SudoString } ;
569+ use crate :: system:: interface:: GroupId ;
570570 use crate :: system:: tests:: tempfile;
571+ use crate :: system:: User ;
571572
572573 static TEST_USER_ID : UserId = UserId :: ROOT ;
573574
575+ fn auth_user_from_uid ( uid : libc:: uid_t ) -> AuthUser {
576+ AuthUser :: from_user_for_targetpw ( User {
577+ uid : UserId :: new ( uid) ,
578+ gid : GroupId :: new ( 0 ) ,
579+ name : SudoString :: new ( "dummy" . to_owned ( ) ) . unwrap ( ) ,
580+ home : SudoPath :: new ( Path :: new ( "/nonexistent" ) . to_owned ( ) ) . unwrap ( ) ,
581+ shell : Path :: new ( "/bin/sh" ) . to_owned ( ) ,
582+ groups : vec ! [ ] ,
583+ } )
584+ }
585+
574586 #[ test]
575587 fn can_encode_and_decode ( ) {
576588 let tty_sample = SessionRecord :: new (
@@ -618,22 +630,22 @@ mod tests {
618630
619631 let tty_sample = SessionRecord :: new ( scope, UserId :: new ( 675 ) ) . unwrap ( ) ;
620632
621- assert ! ( tty_sample. matches( & scope, UserId :: new ( 675 ) ) ) ;
622- assert ! ( !tty_sample. matches( & scope, UserId :: new ( 789 ) ) ) ;
633+ assert ! ( tty_sample. matches( & scope, & auth_user_from_uid ( 675 ) ) ) ;
634+ assert ! ( !tty_sample. matches( & scope, & auth_user_from_uid ( 789 ) ) ) ;
623635 assert ! ( !tty_sample. matches(
624636 & RecordScope :: Tty {
625637 tty_device: DeviceId :: new( 20 ) ,
626638 session_pid: ProcessId :: new( 1234 ) ,
627639 init_time
628640 } ,
629- UserId :: new ( 675 ) ,
641+ & auth_user_from_uid ( 675 ) ,
630642 ) ) ;
631643 assert ! ( !tty_sample. matches(
632644 & RecordScope :: Ppid {
633645 group_pid: ProcessId :: new( 42 ) ,
634646 init_time
635647 } ,
636- UserId :: new ( 675 ) ,
648+ & auth_user_from_uid ( 675 ) ,
637649 ) ) ;
638650
639651 // make sure time is different
@@ -644,7 +656,7 @@ mod tests {
644656 session_pid: ProcessId :: new( 1234 ) ,
645657 init_time: ProcessCreateTime :: new( 1 , 1 )
646658 } ,
647- UserId :: new ( 675 ) ,
659+ & auth_user_from_uid ( 675 ) ,
648660 ) ) ;
649661 }
650662
@@ -721,22 +733,22 @@ mod tests {
721733 session_pid : ProcessId :: new ( 0 ) ,
722734 init_time : ProcessCreateTime :: new ( 0 , 0 ) ,
723735 } ;
724- let auth_user = UserId :: new ( 2424 ) ;
725- let res = srf. create ( tty_scope, auth_user) . unwrap ( ) ;
736+ let auth_user = auth_user_from_uid ( 2424 ) ;
737+ let res = srf. create ( tty_scope, & auth_user) . unwrap ( ) ;
726738 let CreateResult :: Created { time } = res else {
727739 panic ! ( "Expected record to be created" ) ;
728740 } ;
729741
730742 std:: thread:: sleep ( std:: time:: Duration :: from_millis ( 1 ) ) ;
731- let second = srf. touch ( tty_scope, auth_user) . unwrap ( ) ;
743+ let second = srf. touch ( tty_scope, & auth_user) . unwrap ( ) ;
732744 let TouchResult :: Updated { old_time, new_time } = second else {
733745 panic ! ( "Expected record to be updated" ) ;
734746 } ;
735747 assert_eq ! ( time, old_time) ;
736748 assert_ne ! ( old_time, new_time) ;
737749
738750 std:: thread:: sleep ( std:: time:: Duration :: from_millis ( 1 ) ) ;
739- let res = srf. create ( tty_scope, auth_user) . unwrap ( ) ;
751+ let res = srf. create ( tty_scope, & auth_user) . unwrap ( ) ;
740752 let CreateResult :: Updated { old_time, new_time } = res else {
741753 panic ! ( "Expected record to be updated" ) ;
742754 } ;
0 commit comments