fix: openai transcription model response

This commit is contained in:
Chidi Williams 2024-01-11 09:35:56 +00:00
parent 43c3f66aa7
commit 397dadd7a2
2 changed files with 20 additions and 6 deletions

View file

@@ -103,7 +103,7 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
response_format="verbose_json",
language=self.transcription_task.transcription_options.language,
)
if self.transcription_task.transcription_options.task == Task.TRANSLATE
if self.transcription_task.transcription_options.task == Task.TRANSCRIBE
else self.openai_client.audio.translations.create(
model="whisper-1", file=file, response_format="verbose_json"
)
@@ -115,7 +115,7 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
int(segment["end"] * 1000 + offset_ms),
segment["text"],
)
for segment in transcript["segments"]
for segment in transcript.model_extra["segments"]
]
def stop(self):

View file

@@ -12,6 +12,8 @@ from buzz.transcriber.transcriber import (
FileTranscriptionOptions,
)
from openai.types.audio import Transcription, Translation
class TestOpenAIWhisperAPIFileTranscriber:
@pytest.fixture
@@ -19,13 +21,23 @@ class TestOpenAIWhisperAPIFileTranscriber:
with patch(
"buzz.transcriber.openai_whisper_api_file_transcriber.OpenAI"
) as mock:
return_value = {"segments": [{"start": 0, "end": 6.56, "text": "Hello"}]}
mock.return_value.audio.transcriptions.create.return_value = return_value
mock.return_value.audio.translations.create.return_value = return_value
return_value = {
"text": "",
"segments": [{"start": 0, "end": 6.56, "text": "Hello"}],
}
mock.return_value.audio.transcriptions.create.return_value = Transcription(
**return_value
)
mock.return_value.audio.translations.create.return_value = Translation(
**return_value
)
yield mock
def test_transcribe(self, mock_openai_client):
file_path = "testdata/whisper-french.mp3"
file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"../../testdata/whisper-french.mp3",
)
transcriber = OpenAIWhisperAPIFileTranscriber(
task=FileTranscriptionTask(
file_path=file_path,
@@ -44,6 +56,8 @@ class TestOpenAIWhisperAPIFileTranscriber:
transcriber.completed.connect(mock_completed)
transcriber.run()
mock_openai_client.return_value.audio.transcriptions.create.assert_called()
called_segments = mock_completed.call_args[0][0]
assert len(called_segments) == 1