first commit

Lorenzo Iovino 2025-05-23 14:45:22 +02:00
commit b6d47982c7
8 changed files with 1353 additions and 0 deletions

parallel_transcript.py (new file, 398 lines)

@@ -0,0 +1,398 @@
import os
import inspect  # used to detect optional Whisper features
import argparse
import whisper
import torch
import time
import threading
import json
from pyannote.audio import Pipeline
from datetime import timedelta
import numpy as np
from pydub import AudioSegment
import math
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
def start_spinner():
    """Print a dot every second while a long-running task executes."""
    global spinner_done
    spinner_done = False
    def spin():
        while not spinner_done:
            print(".", end="", flush=True)
            time.sleep(1)
    t = threading.Thread(target=spin)
    t.start()
    return t
def stop_spinner(thread):
    global spinner_done
    spinner_done = True
    thread.join()
    print("")
def extend_audio_beginning(input_audio, output_audio, silence_duration=2000):
    """Add a short silence at the beginning of the audio to better capture the first seconds"""
    print(f"🔄 Adding {silence_duration/1000} seconds of silence at the beginning of the audio...")
    audio = AudioSegment.from_file(input_audio)
    silence = AudioSegment.silent(duration=silence_duration)  # duration is in milliseconds (2 seconds by default)
    extended_audio = silence + audio
    extended_audio.export(output_audio, format="wav")
    print(f"✅ Extended audio saved as {output_audio}")
    return output_audio
def transcribe_audio(audio_path, model_size="large"):
    """Transcribe the audio with Whisper using optimized settings"""
    print(f"🔹 Transcribing with Whisper ({model_size})...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"⚙️ Using device: {device.upper()}")
    model = whisper.load_model(model_size).to(device)
    # Advanced settings to improve speech detection
    options = {
        "language": "it",
        "condition_on_previous_text": True,  # Improves coherence between segments
        "suppress_tokens": [-1],  # -1 suppresses the default set of non-speech tokens
        "initial_prompt": "Trascrizione di una conversazione tra tre persone."  # Context prompt, kept in Italian to match the audio language
    }
    # Check whether this version of Whisper supports word-level timestamps by
    # inspecting the transcribe signature instead of running a throwaway transcription
    if "word_timestamps" in inspect.signature(whisper.transcribe).parameters:
        options["word_timestamps"] = True
        print("✅ Using word-level timestamps")
    else:
        print("⚠️ This version of Whisper does not support word-level timestamps")
    spinner = start_spinner()
    start = time.time()
    result = model.transcribe(audio_path, **options)
    stop_spinner(spinner)
    duration = time.time() - start
    print(f"✅ Transcription completed in {round(duration, 2)} seconds")
    # Also save the word timestamps for post-processing
    with open(f"{TRANSCRIPT_FILE}.words.json", "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)
    with open(TRANSCRIPT_FILE, "w", encoding="utf-8") as f:
        f.write(result["text"])
    return result["segments"]
def diarize_audio(audio_path, hf_token, num_speakers=None):
    """Audio diarization with parameters tuned for overlapping speech"""
    print("🔹 Speaker recognition (v3.1) with Pyannote...")
    # Load the model without trying to tweak its parameters
    pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=hf_token)
    spinner = start_spinner()
    start = time.time()
    # Use the given number of speakers if specified
    if num_speakers is not None:
        print(f" Using {num_speakers} speakers as specified")
        diarization = pipeline(audio_path, num_speakers=num_speakers)
    else:
        diarization = pipeline(audio_path)
    stop_spinner(spinner)
    duration = time.time() - start
    print(f"✅ Speakers identified in {round(duration, 2)} seconds")
    # Analyze the identified speakers
    speakers = set()
    for segment, _, speaker in diarization.itertracks(yield_label=True):
        speakers.add(speaker)
    print(f"👥 Identified {len(speakers)} speakers: {', '.join(sorted(speakers))}")
    # Save the raw diarization for inspection
    with open(f"{OUTPUT_FILE}.diarization.json", "w", encoding="utf-8") as f:
        segments = []
        for segment, _, speaker in diarization.itertracks(yield_label=True):
            segments.append({
                "start": segment.start,
                "end": segment.end,
                "speaker": speaker
            })
        json.dump(segments, f, indent=2)
    return diarization
def format_time(seconds, srt=False):
    """Format a time value in a human-readable form"""
    td = timedelta(seconds=float(seconds))
    hours, remainder = divmod(td.seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    milliseconds = round(td.microseconds / 1000)
    if srt:
        return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
    else:
        return f"{hours:d}:{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
def find_overlapping_speech(diarization, threshold=0.5):
    """Identify segments with overlapping speech"""
    overlap_segments = []
    speaker_segments = {}
    # Group the segments by speaker
    for segment, _, speaker in diarization.itertracks(yield_label=True):
        if speaker not in speaker_segments:
            speaker_segments[speaker] = []
        speaker_segments[speaker].append((segment.start, segment.end))
    # Find overlaps between different speakers
    speakers = list(speaker_segments.keys())
    for i in range(len(speakers)):
        for j in range(i+1, len(speakers)):
            speaker1 = speakers[i]
            speaker2 = speakers[j]
            for seg1_start, seg1_end in speaker_segments[speaker1]:
                for seg2_start, seg2_end in speaker_segments[speaker2]:
                    # Check whether the two segments overlap
                    if seg1_start < seg2_end and seg2_start < seg1_end:
                        overlap_start = max(seg1_start, seg2_start)
                        overlap_end = min(seg1_end, seg2_end)
                        overlap_duration = overlap_end - overlap_start
                        if overlap_duration >= threshold:
                            overlap_segments.append({
                                "start": overlap_start,
                                "end": overlap_end,
                                "speakers": [speaker1, speaker2],
                                "duration": overlap_duration
                            })
    # Merge nearby overlaps
    if overlap_segments:
        overlap_segments.sort(key=lambda x: x["start"])
        merged = [overlap_segments[0]]
        for current in overlap_segments[1:]:
            previous = merged[-1]
            if current["start"] - previous["end"] < 0.5:  # Less than half a second apart
                # Merge the intervals
                previous["end"] = max(previous["end"], current["end"])
                previous["speakers"] = list(set(previous["speakers"] + current["speakers"]))
                previous["duration"] = previous["end"] - previous["start"]
            else:
                merged.append(current)
        overlap_segments = merged
    return overlap_segments
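# Illustrative shape of one returned entry, given the dict built above (values made up):
# {"start": 12.4, "end": 13.1, "speakers": ["SPEAKER_00", "SPEAKER_01"], "duration": 0.7}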
def match_transcript_to_speakers(transcript_segments, diarization, min_segment_length=1.0, max_chars=150):
    """Match the transcript to the speakers, with improved handling of overlapping speech"""
    print("🔹 Combining transcript + speakers...")
    # Find potential overlaps
    overlaps = find_overlapping_speech(diarization)
    if overlaps:
        print(f" Detected {len(overlaps)} potential speech overlaps")
        with open(f"{OUTPUT_FILE}.overlaps.json", "w", encoding="utf-8") as f:
            json.dump(overlaps, f, indent=2)
    # Combine short segments from the same speaker
    combined_segments = []
    current_segment = None
    for segment, _, speaker in diarization.itertracks(yield_label=True):
        start_time = round(segment.start, 2)
        end_time = round(segment.end, 2)
        # Skip the segment if it is too short
        if end_time - start_time < 0.2:
            continue
        # First segment, or the speaker changed
        if current_segment is None or current_segment["speaker"] != speaker:
            if current_segment is not None:
                combined_segments.append(current_segment)
            current_segment = {
                "start": start_time,
                "end": end_time,
                "speaker": speaker
            }
        else:
            # Extend the current segment
            current_segment["end"] = end_time
    # Add the last segment
    if current_segment is not None:
        combined_segments.append(current_segment)
    # Now match the text to the combined segments
    output_segments = []
    for segment in combined_segments:
        start_time = segment["start"]
        end_time = segment["end"]
        speaker = segment["speaker"]
        # Skip segments that are still too short after combining
        if end_time - start_time < min_segment_length:
            continue
        # Find the text that falls within this time interval
        text = ""
        for s in transcript_segments:
            # The text segment overlaps the speaker segment
            if (s["start"] < end_time and s["end"] > start_time):
                text += s["text"] + " "
        text = text.strip()
        if text:
            # Check whether this segment falls within an overlap
            is_overlap = False
            overlap_speakers = []
            for overlap in overlaps:
                if (start_time < overlap["end"] and end_time > overlap["start"]):
                    is_overlap = True
                    overlap_speakers = overlap["speakers"]
                    break
            # Format the output depending on whether there is an overlap
            if is_overlap and speaker in overlap_speakers:
                speaker_text = f"[{speaker}+] " if len(overlap_speakers) > 1 else f"[{speaker}] "
            else:
                speaker_text = f"[{speaker}] "
            # Build the full segment with its text
            output_segment = {
                "start": start_time,
                "end": end_time,
                "speaker": speaker,
                "speaker_text": speaker_text,
                "text": text
            }
            output_segments.append(output_segment)
    # Now format and save the final output
    output = []
    srt = []
    vtt = ["WEBVTT\n"]
    for i, segment in enumerate(output_segments, 1):
        start_time = segment["start"]
        end_time = segment["end"]
        speaker_text = segment["speaker_text"]
        text = segment["text"]
        formatted_text = f"{speaker_text}({format_time(start_time)} - {format_time(end_time)}): {text}"
        srt_text = f"{i}\n{format_time(start_time, True)} --> {format_time(end_time, True)}\n{speaker_text}{text}"
        vtt_text = f"{format_time(start_time)} --> {format_time(end_time)}\n{speaker_text}{text}"
        output.append(formatted_text)
        srt.append(srt_text)
        vtt.append(vtt_text)
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        f.write("\n".join(output))
    with open(SRT_FILE, "w", encoding="utf-8") as f:
        f.write("\n\n".join(srt))
    with open(VTT_FILE, "w", encoding="utf-8") as f:
        f.write("\n\n".join(vtt))
    print("✅ Final output saved:", OUTPUT_FILE, SRT_FILE, VTT_FILE)
def parse_timestamp(time_str):
    """Convert a timestamp string to seconds"""
    # Handle both SRT (00:00:00,000) and standard format (00:00:00.000)
    time_str = time_str.replace(',', '.')
    hours, minutes, seconds = time_str.split(':')
    hours = int(hours)
    minutes = int(minutes)
    seconds = float(seconds)
    total_seconds = hours * 3600 + minutes * 60 + seconds
    return total_seconds
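# Illustrative example (not part of the original script):
# parse_timestamp("00:01:30,500") evaluates to 90.5 seconds.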
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Transcription + advanced speaker diarization")
    parser.add_argument("--audio", help="WAV audio file", required=True)
    parser.add_argument("--token", help="Hugging Face token for Pyannote")
    parser.add_argument("--model", default="large", help="Whisper model (tiny, base, medium, large)")
    parser.add_argument("--no-diarization", action="store_true", help="Disable speaker recognition")
    parser.add_argument("--output-prefix", default="transcript", help="Prefix for the output files")
    parser.add_argument("--num-speakers", type=int, default=None, help="Number of speakers, if known in advance")
    parser.add_argument("--fix-start", action="store_true", help="Add silence at the beginning to better capture the first seconds")
    parser.add_argument("--min-segment", type=float, default=1.0, help="Minimum segment length in seconds")
    args = parser.parse_args()
    # Use the Hugging Face token from the environment variable if not provided as an argument
    hf_token = args.token or os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("Hugging Face token not provided. Pass it with --token or set HF_TOKEN in the .env file")
    # Use the model from the environment variable if available
    model_size = os.getenv("WHISPER_MODEL", args.model)
    # Use the number of speakers from the environment variable if available and not provided as an argument
    num_speakers = args.num_speakers
    if num_speakers is None and os.getenv("NUM_SPEAKERS"):
        try:
            num_speakers = int(os.getenv("NUM_SPEAKERS"))
        except ValueError:
            pass
    # Use fix-start from the environment variable if available and not provided as an argument
    fix_start = args.fix_start
    if not fix_start and os.getenv("FIX_START", "").lower() == "true":
        fix_start = True
    # Output file names
    output_prefix = os.getenv("OUTPUT_PREFIX", args.output_prefix)
    TRANSCRIPT_FILE = f"{output_prefix}.txt"
    OUTPUT_FILE = f"{output_prefix}_final.txt"
    SRT_FILE = f"{output_prefix}.srt"
    VTT_FILE = f"{output_prefix}.vtt"
    if not os.path.exists(args.audio):
        raise ValueError(f"Audio file {args.audio} not found")
    # Add silence at the beginning if requested
    input_audio = args.audio
    if fix_start:
        extended_audio = "extended_" + os.path.basename(args.audio)
        input_audio = extend_audio_beginning(args.audio, extended_audio)
    # Transcribe the audio
    segments = transcribe_audio(input_audio, model_size)
    # Run diarization and match the transcript to the speakers
    if not args.no_diarization:
        diarization = diarize_audio(input_audio, hf_token, num_speakers)
        match_transcript_to_speakers(segments, diarization, args.min_segment)
    else:
        print("🛑 Diarization disabled. Saving the transcript only.")
        with open(OUTPUT_FILE, "w", encoding="utf-8") as f_out, open(SRT_FILE, "w", encoding="utf-8") as f_srt, open(VTT_FILE, "w", encoding="utf-8") as f_vtt:
            f_vtt.write("WEBVTT\n\n")
            for i, s in enumerate(segments, 1):
                start = format_time(s['start'])
                end = format_time(s['end'])
                f_out.write(f"({start} - {end}): {s['text'].strip()}\n")
                f_srt.write(f"{i}\n{format_time(s['start'], True)} --> {format_time(s['end'], True)}\n{s['text'].strip()}\n\n")
                f_vtt.write(f"{start} --> {end}\n{s['text'].strip()}\n\n")
        print(f"✅ Output saved without diarization: {OUTPUT_FILE}, {SRT_FILE}, {VTT_FILE}")
    # Remove the extended audio file if one was created (check fix_start, which also covers the FIX_START env variable)
    if fix_start and os.path.exists(extended_audio):
        os.remove(extended_audio)
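# --- Usage sketch (illustrative, not part of the original script) ---
# The flags come from the argparse definition above; the audio file name is an example only:
#
#   python parallel_transcript.py --audio meeting.wav --model medium --num-speakers 3 --fix-start
#
# The same settings can be supplied through a .env file read by python-dotenv,
# using the variable names referenced by the os.getenv calls above:
#
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxx
#   WHISPER_MODEL=medium
#   NUM_SPEAKERS=3
#   FIX_START=true
#   OUTPUT_PREFIX=transcript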