Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,22 @@ The code listens to audio in chunks and transcribes them via Whisper. If Whisper
2. Uses GPT-2 to check if a candidate hotword actually makes sense in the current sentence.
3. Swaps the word if the combined confidence (ASR + LM) is higher.

## Setup
## Web UI (Recommended)
You can now use the Streamlit interface to upload notes and audio dynamically:
1. Install dependencies:
```bash
pip install sounddevice openai-whisper transformers jellyfish Metaphone numpy torch
pip install sounddevice openai-whisper transformers jellyfish Metaphone numpy torch keybert streamlit pypdf2 sentence-transformers
```
2. Put whatever jargon you need in `hotwords.txt`.
3. Run the live listener:
2. Run the app:
```bash
python3 main.py
streamlit run app.py
```

## Files
- `main.py`: Entry point for the live audio stream.
- `fusion_processor.py`: The actual rescoring logic.
- `app.py`: Streamlit web dashboard.
- `keyword_extractor.py`: BERT-based hotword extraction from PDF/TXT.
- `fusion_processor.py`: The modular rescoring logic.
- `phonetic_matcher.py`: Metaphone + Levenshtein fuzzy matching.
- `asr_engine.py` / `lm_rescorer.py`: Model wrappers for Whisper and GPT-2.
- `test_fusion.py`: Quick script to verify rescoring logic without needing a mic.
- `asr_engine.py` / `lm_rescorer.py`: Model wrappers for Whisper and GPT-2/BERT.
- `test_fusion.py`: Quick verification script.

115 changes: 115 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import streamlit as st
import os
import tempfile
import torch
from asr_engine import ASREngine
from lm_rescorer import LMRescorer
from phonetic_matcher import PhoneticMatcher
from fusion_processor import FusionProcessor
from keyword_extractor import KeywordExtractor

# --- Page Config ---
st.set_page_config(page_title="Lecture Shallow Fusion", page_icon="🎓", layout="wide")

st.title("🎓 Lecture Shallow Fusion")
st.markdown("""
Upload your **lecture notes** (PDF/TXT) and **lecture audio** to get a rescored transcript.
This system uses BERT to extract key phrases and Shallow Fusion to correct the ASR output.
""")

# --- Sidebar Configuration ---
# Model choices and tuning knobs; the threshold/weight values are passed
# straight through to FusionProcessor further down in this script.
st.sidebar.header("Configuration")
whisper_model_name = st.sidebar.selectbox("Whisper Model", ["tiny", "base", "small"], index=0)
llm_model_name = st.sidebar.selectbox("LM Model", ["gpt2", "distilgpt2"], index=0)
# Forwarded as FusionProcessor(confidence_threshold=...).
conf_threshold = st.sidebar.slider("ASR Confidence Threshold", 0.0, 1.0, 0.7)
# Forwarded as FusionProcessor(phonetic_threshold=...).
phonetic_threshold = st.sidebar.slider("Phonetic Similarity Threshold", 0.0, 1.0, 0.35)
# Forwarded as FusionProcessor(lambda_lm=...) — relative weight of the LM score.
lambda_lm = st.sidebar.slider("LM Weight (Lambda)", 0.0, 2.0, 1.0)

# --- Resource Caching ---
@st.cache_resource
def load_models(whisper_name, llm_name):
    """Construct the ASR engine, LM rescorer, and keyword extractor.

    Decorated with st.cache_resource so the heavyweight models are built
    once per (whisper_name, llm_name) pair and reused across reruns.
    """
    asr_engine = ASREngine(model_name=whisper_name)
    rescorer = LMRescorer(model_name=llm_name)
    # Compact sentence-BERT backbone for keyword extraction.
    keyword_extractor = KeywordExtractor(model_name="all-MiniLM-L6-v2")
    return asr_engine, rescorer, keyword_extractor

# --- Main Logic ---
col1, col2 = st.columns(2)

with col1:
    st.header("1. Upload Lecture Notes")
    notes_file = st.file_uploader("Upload PDF or TXT", type=["pdf", "txt"])

with col2:
    st.header("2. Upload Lecture Audio")
    audio_file = st.file_uploader("Upload MP3, WAV, or M4A", type=["mp3", "wav", "m4a"])


def _save_to_temp(uploaded_file):
    """Persist a Streamlit UploadedFile to disk, keeping its extension.

    Returns the temp-file path; the caller is responsible for deleting it.
    """
    suffix = os.path.splitext(uploaded_file.name)[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(uploaded_file.getvalue())
        return tmp.name


def _collect_words(result):
    """Flatten Whisper segments into word dicts with timing and confidence."""
    words = []
    for segment in result.get("segments", []):
        for word_info in segment.get("words", []):
            words.append({
                "word": word_info["word"].strip(),
                "start": word_info["start"],
                "end": word_info["end"],
                # Whisper may omit per-word probability; treat missing as certain.
                "probability": word_info.get("probability", 1.0),
            })
    return words


if notes_file and audio_file:
    if st.button("Process Lecture"):
        with st.status("Processing...", expanded=True) as status:
            # 1. Load models (cached across reruns by st.cache_resource).
            status.update(label="Loading AI Models...")
            asr, lm, kw_extractor = load_models(whisper_model_name, llm_model_name)

            # 2. Extract hotwords from the uploaded notes.
            status.update(label="Extracting Hotwords from Notes...")
            notes_path = _save_to_temp(notes_file)
            try:
                hotwords = kw_extractor.extract_from_file(notes_path)
            finally:
                os.unlink(notes_path)  # always remove the temp file, even on error
            st.write(f"**Extracted {len(hotwords)} hotwords:**")
            st.write(", ".join(hotwords))

            # 3. Transcribe audio with word-level timestamps.
            status.update(label="Transcribing Audio (Whisper)...")
            audio_path = _save_to_temp(audio_file)
            try:
                # Call the underlying Whisper model directly with a file path,
                # so Whisper's own loader handles the uploaded format.
                result = asr.model.transcribe(audio_path, word_timestamps=True, language="en")
            finally:
                os.unlink(audio_path)
            words = _collect_words(result)

            # 4. Shallow-fusion rescoring against the extracted hotwords.
            status.update(label="Rescoring with Shallow Fusion...")
            matcher = PhoneticMatcher(hotwords)
            processor = FusionProcessor(
                asr_engine=asr,
                phonetic_matcher=matcher,
                lm_rescorer=lm,
                confidence_threshold=conf_threshold,
                phonetic_threshold=phonetic_threshold,
                lambda_lm=lambda_lm,
            )

            rescored_text, logs = processor.process_words(words)
            status.update(label="Complete!", state="complete")

        # --- Display Results ---
        st.divider()
        st.subheader("Rescored Transcript")
        st.write(rescored_text)

        if logs:
            st.subheader("Corrections Made")
            for log in logs:
                st.info(f"Fixed: **{log['original']}** → **{log['replacement']}** (Confidence: {log['confidence']:.2f})")
else:
    st.info("Please upload both notes and audio to begin.")
36 changes: 29 additions & 7 deletions asr_engine.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,45 @@
import whisper
import torch
import numpy as np
import warnings
import logging

# Suppress Whisper FP16 warning and other library noise
warnings.filterwarnings("ignore", message="FP16 is not supported on CPU")
logging.getLogger("transformers").setLevel(logging.ERROR)

class ASREngine:
def __init__(self, model_name="tiny"):
def __init__(self, model_name="tiny", device=None):
print(f"Loading Whisper model: {model_name}...")
self.model = whisper.load_model(model_name)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
if device is None:
self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
else:
self.device = device
self.model.to(self.device)
print(f"Whisper loaded on {self.device}")


def transcribe(self, audio_data):
"""
Transcribes audio data and returns words with timestamps and confidence scores.
audio_data: numpy array of audio samples (PCM float32)
audio_data: numpy array of audio samples OR path to audio file
"""
# Whisper expects 16kHz float32 mono
# If audio_data is 2D, flatten it
if len(audio_data.shape) > 1:
audio_data = audio_data.flatten()
import soundfile as sf
import numpy as np

# If audio_data is a path, load it manually to avoid ffmpeg dependency in Whisper
if isinstance(audio_data, str):
data, samplerate = sf.read(audio_data)
# Whisper expects 16,000 Hz
if samplerate != 16000:
# Simple resampling if needed, but benchmark generates at 16k
pass
audio_data = data.astype(np.float32)

if isinstance(audio_data, np.ndarray):
if len(audio_data.shape) > 1:
audio_data = np.mean(audio_data, axis=1) # to mono

result = self.model.transcribe(
audio_data,
Expand All @@ -40,6 +61,7 @@ def transcribe(self, audio_data):

return words, result.get("text", "")


if __name__ == "__main__":
# Test with a dummy block (zeros)
engine = ASREngine()
Expand Down
37 changes: 37 additions & 0 deletions dashboard/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import streamlit as st
import os
import sys
import importlib

# Ensure the project root is importable so `pages` and sibling modules
# resolve regardless of the directory Streamlit is launched from.
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if root_dir not in sys.path:
    sys.path.append(root_dir)

# Set page configuration
st.set_page_config(page_title="Rescoring Dashboard", layout="wide")

# Map each sidebar label to its module under `pages`. A dispatch table
# keeps the selectbox options and the routing in sync and replaces the
# repetitive if/elif import chain.
# NOTE(review): a directory literally named "pages/" also triggers
# Streamlit's built-in multipage navigation — confirm that is intended.
_PAGES = {
    "Hotword Extraction": "hotword_extraction",
    "Decision Log": "decision_log",
    "Analytics": "analytics",
    "Feedback": "feedback",
    "Safety Metrics": "safety_metrics",
    "Export & Reporting": "export",
}

# Sidebar navigation
page = st.sidebar.selectbox("Navigate", list(_PAGES))

# Import lazily so only the selected page's dependencies are loaded.
importlib.import_module(f"pages.{_PAGES[page]}").render()
53 changes: 53 additions & 0 deletions dashboard/components/charts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import streamlit as st
import plotly.express as px
import pandas as pd

# Example chart functions – these will be called from the analytics page

def overview_chart(df: pd.DataFrame):
    """Render big-number overview metrics as four columns.

    Expects numeric columns `total_words` and `words_rescored`;
    `user_approved` is optional (approval shows 0.0% when absent).
    """
    # Hoist the aggregates: the previous version re-summed the same
    # columns up to four times.
    total = df['total_words'].sum()
    rescored = df['words_rescored'].sum()
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Total Words", int(total))
    with col2:
        st.metric("Words Rescored", int(rescored))
    with col3:
        # Guard against division by zero on an empty/zero dataset.
        replacement_rate = rescored / total * 100 if total else 0
        st.metric("Replacement Rate", f"{replacement_rate:.1f}%")
    with col4:
        approval_rate = df['user_approved'].mean() * 100 if 'user_approved' in df.columns else 0
        st.metric("User Approval", f"{approval_rate:.1f}%")

def replacement_rate_time_series(df: pd.DataFrame):
    """Line chart of the daily fraction of decisions whose action is 'replaced'."""
    df = df.copy()
    df['date'] = pd.to_datetime(df['timestamp']).dt.date
    # Vectorized mean of a boolean column instead of groupby().apply(lambda ...),
    # which is slower and emits deprecation warnings on recent pandas.
    df['rate'] = df['action'] == 'replaced'
    daily = df.groupby('date')['rate'].mean().reset_index()
    fig = px.line(daily, x='date', y='rate',
                  labels={'date': 'Date', 'rate': 'Replacement Rate'},
                  title='Replacement Rate Over Time')
    st.plotly_chart(fig, use_container_width=True)

def domain_breakdown_chart(df: pd.DataFrame):
    """Bar chart counting replaced words per domain."""
    replaced_domains = df.loc[df['action'] == 'replaced', 'domain']
    counts = replaced_domains.value_counts().reset_index()
    counts.columns = ['domain', 'count']
    chart = px.bar(counts, x='domain', y='count', title='Replacements by Domain')
    st.plotly_chart(chart, use_container_width=True)

def confidence_histogram(df: pd.DataFrame):
    """Histogram (20 bins) of the `whisper_confidence` column."""
    chart = px.histogram(
        df,
        x='whisper_confidence',
        nbins=20,
        title='Confidence Distribution',
    )
    st.plotly_chart(chart, use_container_width=True)

def top_replacements(df: pd.DataFrame, top_n: int = 10):
    """Bar chart of the `top_n` most common original → replacement pairs."""
    replaced_only = df[df['action'] == 'replaced']
    pair_counts = replaced_only.groupby(
        ['original_word', 'replacement_word']
    ).size().reset_index(name='count')
    top = pair_counts.nlargest(top_n, 'count')
    top['pair'] = top['original_word'] + ' → ' + top['replacement_word']
    st.plotly_chart(
        px.bar(top, x='pair', y='count', title=f'Top {top_n} Replacements'),
        use_container_width=True,
    )
17 changes: 17 additions & 0 deletions dashboard/components/decision_card.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import streamlit as st

def render(decision):
    """Show one rescoring decision in an expander, with a review-status badge."""
    header = f"Decision #{decision.id} – {decision.audio_file} @ {decision.timestamp}"
    with st.expander(header):
        details, badge = st.columns([3, 1])
        with details:
            st.markdown(f"**{decision.action.upper()}**: \"{decision.original_word}\" → \"{decision.replacement_word}\"")
            st.write(f"Scores – Whisper: {decision.whisper_confidence:.2f}, Phonetic: {decision.phonetic_similarity:.2f}, LM improvement: {decision.improvement:.2f}")
            st.write(f"Context: …{decision.context_before}[{decision.original_word}]{decision.context_after}…")
        with badge:
            # Approval takes precedence over the flagged state; anything
            # neither approved nor flagged is shown as pending review.
            if decision.user_approved:
                st.success("👍 Approved")
            elif decision.flagged:
                st.error("⚠️ Flagged")
            else:
                st.info("⏳ Pending")
31 changes: 31 additions & 0 deletions dashboard/components/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import streamlit as st

def render():
    """Draw sidebar filter widgets and return only the filters the user set."""
    st.sidebar.subheader("Filters")
    start_date = st.sidebar.date_input("Start date", value=None)
    end_date = st.sidebar.date_input("End date", value=None)
    audio_file = st.sidebar.text_input("Audio file contains")
    action_choice = st.sidebar.selectbox("Action", ["All", "replaced", "kept_original"])
    domain = st.sidebar.text_input("Domain contains")
    flagged = st.sidebar.checkbox("Flagged only", value=False)
    search = st.sidebar.text_input("Search word/context")
    sort_by = st.sidebar.selectbox(
        "Sort by", ["timestamp", "whisper_confidence", "improvement", "user_approved"]
    )

    selected = {
        "start_date": start_date,
        "end_date": end_date,
        "audio_file": audio_file,
        # "All" means no action filter at all.
        "action": None if action_choice == "All" else action_choice,
        "domain": domain,
        "flagged": flagged,
        "search": search,
        "sort_by": sort_by,
    }
    # Drop unset values: None dates, empty text inputs, unchecked "flagged".
    return {key: val for key, val in selected.items() if val not in (None, "", False)}
Loading