mirror of
https://github.com/chidiwilliams/buzz.git
synced 2024-06-13 13:42:11 +02:00
fix: openai transcription model response
This commit is contained in:
parent
43c3f66aa7
commit
397dadd7a2
|
@ -103,7 +103,7 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
|||
response_format="verbose_json",
|
||||
language=self.transcription_task.transcription_options.language,
|
||||
)
|
||||
if self.transcription_task.transcription_options.task == Task.TRANSLATE
|
||||
if self.transcription_task.transcription_options.task == Task.TRANSCRIBE
|
||||
else self.openai_client.audio.translations.create(
|
||||
model="whisper-1", file=file, response_format="verbose_json"
|
||||
)
|
||||
|
@ -115,7 +115,7 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
|||
int(segment["end"] * 1000 + offset_ms),
|
||||
segment["text"],
|
||||
)
|
||||
for segment in transcript["segments"]
|
||||
for segment in transcript.model_extra["segments"]
|
||||
]
|
||||
|
||||
def stop(self):
|
||||
|
|
|
@ -12,6 +12,8 @@ from buzz.transcriber.transcriber import (
|
|||
FileTranscriptionOptions,
|
||||
)
|
||||
|
||||
from openai.types.audio import Transcription, Translation
|
||||
|
||||
|
||||
class TestOpenAIWhisperAPIFileTranscriber:
|
||||
@pytest.fixture
|
||||
|
@ -19,13 +21,23 @@ class TestOpenAIWhisperAPIFileTranscriber:
|
|||
with patch(
|
||||
"buzz.transcriber.openai_whisper_api_file_transcriber.OpenAI"
|
||||
) as mock:
|
||||
return_value = {"segments": [{"start": 0, "end": 6.56, "text": "Hello"}]}
|
||||
mock.return_value.audio.transcriptions.create.return_value = return_value
|
||||
mock.return_value.audio.translations.create.return_value = return_value
|
||||
return_value = {
|
||||
"text": "",
|
||||
"segments": [{"start": 0, "end": 6.56, "text": "Hello"}],
|
||||
}
|
||||
mock.return_value.audio.transcriptions.create.return_value = Transcription(
|
||||
**return_value
|
||||
)
|
||||
mock.return_value.audio.translations.create.return_value = Translation(
|
||||
**return_value
|
||||
)
|
||||
yield mock
|
||||
|
||||
def test_transcribe(self, mock_openai_client):
|
||||
file_path = "testdata/whisper-french.mp3"
|
||||
file_path = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)),
|
||||
"../../testdata/whisper-french.mp3",
|
||||
)
|
||||
transcriber = OpenAIWhisperAPIFileTranscriber(
|
||||
task=FileTranscriptionTask(
|
||||
file_path=file_path,
|
||||
|
@ -44,6 +56,8 @@ class TestOpenAIWhisperAPIFileTranscriber:
|
|||
transcriber.completed.connect(mock_completed)
|
||||
transcriber.run()
|
||||
|
||||
mock_openai_client.return_value.audio.transcriptions.create.assert_called()
|
||||
|
||||
called_segments = mock_completed.call_args[0][0]
|
||||
|
||||
assert len(called_segments) == 1
|
||||
|
|
Loading…
Reference in a new issue