-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDeepgramAudioStreamer.cpp
More file actions
610 lines (499 loc) · 20.3 KB
/
DeepgramAudioStreamer.cpp
File metadata and controls
610 lines (499 loc) · 20.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
// Windows audio (WASAPI) — windows.h must be included first.
#include <windows.h>
#include <mmdeviceapi.h>
#include <audioclient.h>
#include <avrt.h>
// C++ standard library
#include <atomic>
#include <chrono>
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <vector>
// WebSocket libraries - using Beast (part of Boost)
#include <boost/beast/core.hpp>
#include <boost/beast/websocket.hpp>
#include <boost/beast/ssl.hpp>
#include <boost/asio/connect.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/ssl/stream.hpp>
// For base64 encoding of API key
#include <boost/beast/core/detail/base64.hpp>
namespace beast = boost::beast;
namespace http = beast::http;
namespace websocket = beast::websocket;
namespace net = boost::asio;
namespace ssl = boost::asio::ssl;
using tcp = boost::asio::ip::tcp;
// Settings for a Deepgram streaming-transcription session.
// Defaults match what the capture pipeline produces: 16 kHz mono
// 16-bit linear PCM, transcribed with the "nova-3" model.
struct DeepgramConfig {
    std::string apiKey;             // Deepgram API key (sent as "Token <key>")
    std::string model = "nova-3";   // transcription model name
    std::string language = "en-US"; // BCP-47 language tag
    bool interimResults = true;     // request partial (interim) transcripts
    bool punctuate = true;          // request punctuated output
    int sampleRate = 16000;         // PCM sample rate in Hz
    int encoding = 16;              // bits per sample (16-bit PCM)
    int channels = 1;               // channel count (mono)
};
// Converts captured WASAPI audio (in the device mix format) to the mono
// 16-bit PCM stream Deepgram expects, plus a basic linear resampler.
class AudioResampler {
public:
    // Convert `frames` interleaved frames at `data`, described by `format`,
    // to mono 16-bit PCM by averaging all channels of each frame.
    //
    // Handled source formats:
    //   - 32-bit IEEE float (incl. WAVE_FORMAT_EXTENSIBLE float)
    //   - 32-bit integer PCM
    //   - 24-bit packed little-endian PCM
    //   - 16-bit PCM
    //   - 8-bit unsigned PCM
    // Any other format yields an empty vector.
    static std::vector<int16_t> ConvertAudio(const BYTE* data, UINT32 frames, WAVEFORMATEX* format) {
        std::vector<int16_t> result;
        // WAVE_FORMAT_IEEE_FLOAT == 3
        bool isFloat = (format->wFormatTag == 3);
        // For WAVE_FORMAT_EXTENSIBLE the real format lives in SubFormat;
        // KSDATAFORMAT_SUBTYPE_IEEE_FLOAT has Data1 == 3.
        if (format->wFormatTag == WAVE_FORMAT_EXTENSIBLE && format->cbSize >= 22) {
            WAVEFORMATEXTENSIBLE* extFormat = reinterpret_cast<WAVEFORMATEXTENSIBLE*>(format);
            isFloat = (extFormat->SubFormat.Data1 == 3);
        }
        result.reserve(frames); // one mono sample per input frame
        if (isFloat && format->wBitsPerSample == 32) {
            // 32-bit float -> 16-bit int, averaged to mono.
            const float* floatData = reinterpret_cast<const float*>(data);
            for (UINT32 frame = 0; frame < frames; frame++) {
                float sampleSum = 0.0f;
                for (UINT16 ch = 0; ch < format->nChannels; ch++) {
                    sampleSum += floatData[frame * format->nChannels + ch];
                }
                float monoSample = sampleSum / format->nChannels;
                // Clamp to [-1, 1] before scaling so the int16 cast can't overflow.
                if (monoSample > 1.0f) monoSample = 1.0f;
                if (monoSample < -1.0f) monoSample = -1.0f;
                result.push_back(static_cast<int16_t>(monoSample * 32767.0f));
            }
        }
        else if (format->wBitsPerSample == 16) {
            // Already 16-bit PCM; just average channels to mono.
            const int16_t* pcmData = reinterpret_cast<const int16_t*>(data);
            for (UINT32 frame = 0; frame < frames; frame++) {
                int32_t sampleSum = 0;
                for (UINT16 ch = 0; ch < format->nChannels; ch++) {
                    sampleSum += pcmData[frame * format->nChannels + ch];
                }
                result.push_back(static_cast<int16_t>(sampleSum / format->nChannels));
            }
        }
        else if (format->wBitsPerSample == 32 && !isFloat) {
            // 32-bit integer PCM: average in 64 bits, then keep the top 16 bits.
            const int32_t* intData = reinterpret_cast<const int32_t*>(data);
            for (UINT32 frame = 0; frame < frames; frame++) {
                int64_t sampleSum = 0;
                for (UINT16 ch = 0; ch < format->nChannels; ch++) {
                    sampleSum += intData[frame * format->nChannels + ch];
                }
                result.push_back(static_cast<int16_t>((sampleSum / format->nChannels) >> 16));
            }
        }
        else if (format->wBitsPerSample == 24) {
            // 24-bit packed PCM: reassemble each little-endian sample into the
            // top 24 bits of an int32 so the sign bit lands in bit 31.
            for (UINT32 frame = 0; frame < frames; frame++) {
                // BUGFIX: accumulate in 64 bits — summing multiple channels of
                // values scaled into the top of int32 overflowed the old int32 sum.
                int64_t sampleSum = 0;
                for (UINT16 ch = 0; ch < format->nChannels; ch++) {
                    size_t sampleOffset = (frame * format->nChannels + ch) * 3; // 3 bytes per sample
                    int32_t sample = (data[sampleOffset] << 8) |
                                     (data[sampleOffset + 1] << 16) |
                                     (data[sampleOffset + 2] << 24);
                    sampleSum += sample;
                }
                // BUGFIX: the averaged value occupies the top 24 bits of the
                // int32, so it must be shifted right by 16 (not 8) to produce
                // a 16-bit sample; the old >> 8 overflowed the int16 cast.
                result.push_back(static_cast<int16_t>((sampleSum / format->nChannels) >> 16));
            }
        }
        else if (format->wBitsPerSample == 8) {
            // 8-bit unsigned PCM (0..255) -> signed 16-bit, averaged to mono.
            for (UINT32 frame = 0; frame < frames; frame++) {
                int32_t sampleSum = 0;
                for (UINT16 ch = 0; ch < format->nChannels; ch++) {
                    int16_t sample = static_cast<int16_t>(
                        (static_cast<int>(data[frame * format->nChannels + ch]) - 128) << 8);
                    sampleSum += sample;
                }
                result.push_back(static_cast<int16_t>(sampleSum / format->nChannels));
            }
        }
        return result;
    }
    // Linear-interpolation resampler from originalRate to targetRate.
    // Adequate for speech; a production pipeline should use a windowed-sinc
    // or polyphase filter to avoid aliasing on downsampling.
    static std::vector<int16_t> Resample(const std::vector<int16_t>& input, int originalRate, int targetRate) {
        // Nothing to do when rates already match or there is no data.
        if (originalRate == targetRate || input.empty()) {
            return input;
        }
        double ratio = static_cast<double>(originalRate) / targetRate;
        size_t outputSize = static_cast<size_t>(input.size() / ratio);
        std::vector<int16_t> output(outputSize);
        for (size_t i = 0; i < outputSize; i++) {
            double pos = i * ratio;
            size_t idx = static_cast<size_t>(pos);
            if (idx >= input.size()) { // guard against floating-point rounding
                idx = input.size() - 1;
            }
            double frac = pos - idx;
            if (idx + 1 < input.size()) {
                // Interpolate between the two neighbouring source samples.
                output[i] = static_cast<int16_t>((1.0 - frac) * input[idx] + frac * input[idx + 1]);
            } else {
                output[i] = input[idx]; // last sample: nothing to interpolate with
            }
        }
        return output;
    }
};
// Deepgram WebSocket client
class DeepgramClient {
private:
net::io_context ioc;
ssl::context ctx{ssl::context::tlsv12_client};
websocket::stream<ssl::stream<tcp::socket>> ws;
DeepgramConfig config;
std::thread receiveThread;
std::atomic<bool> isConnected{false};
std::atomic<bool> shouldStop{false};
public:
DeepgramClient(const DeepgramConfig& config)
: config(config), ws(ioc, ctx) {
// Set up SSL
ctx.set_default_verify_paths();
ctx.set_verify_mode(ssl::verify_peer);
}
~DeepgramClient() {
disconnect();
}
bool connect() {
try {
// Look up the domain name
auto const results = resolver.resolve("api.deepgram.com", "443");
// Make the connection on the IP address we get
net::connect(ws.next_layer().next_layer(), results);
// Perform the SSL handshake
ws.next_layer().handshake(ssl::stream_base::client);
// Set up WebSocket connection
ws.set_option(websocket::stream_base::decorator(
[this](websocket::request_type& req) {
req.set(http::field::authorization, "Token " + config.apiKey);
}));
// Build the WebSocket URL with query parameters
std::string target = "/v1/listen?";
target += "model=" + config.model;
target += "&language=" + config.language;
target += "&interim_results=" + std::string(config.interimResults ? "true" : "false");
target += "&punctuate=" + std::string(config.punctuate ? "true" : "false");
target += "&encoding=linear16";
target += "&sample_rate=" + std::to_string(config.sampleRate);
target += "&channels=" + std::to_string(config.channels);
// Perform the WebSocket handshake
ws.handshake("api.deepgram.com", target);
isConnected = true;
// Start the message receiving thread
receiveThread = std::thread([this]() {
receiveMessages();
});
return true;
}
catch(std::exception const& e) {
std::cerr << "Error connecting to Deepgram: " << e.what() << std::endl;
return false;
}
}
void disconnect() {
if (!isConnected) return;
shouldStop = true;
try {
// Close the WebSocket connection
ws.close(websocket::close_code::normal);
// Wait for the receive thread to finish
if (receiveThread.joinable()) {
receiveThread.join();
}
isConnected = false;
}
catch(std::exception const& e) {
std::cerr << "Error disconnecting from Deepgram: " << e.what() << std::endl;
}
}
bool sendAudio(const std::vector<int16_t>& audioData) {
if (!isConnected) return false;
try {
// Send the audio data as binary
ws.binary(true);
ws.write(net::buffer(audioData.data(), audioData.size() * sizeof(int16_t)));
return true;
}
catch(std::exception const& e) {
std::cerr << "Error sending audio to Deepgram: " << e.what() << std::endl;
return false;
}
}
private:
tcp::resolver resolver{ioc};
void receiveMessages() {
beast::flat_buffer buffer;
while (!shouldStop) {
try {
// Read a message into the buffer
ws.read(buffer);
// Parse the JSON response
std::string message(beast::buffers_to_string(buffer.data()));
buffer.consume(buffer.size());
// Print the transcription result
std::cout << "Deepgram transcription: " << message << std::endl;
}
catch(websocket::close_reason const& reason) {
std::cerr << "WebSocket closed: " << reason.reason << std::endl;
break;
}
catch(std::exception const& e) {
if (!shouldStop) {
std::cerr << "Error receiving message from Deepgram: " << e.what() << std::endl;
}
break;
}
}
}
};
// Main audio capture and streaming class
class AudioStreamer {
private:
IMMDeviceEnumerator* pEnumerator = nullptr;
IMMDevice* pDevice = nullptr;
IAudioClient* pAudioClient = nullptr;
IAudioCaptureClient* pCaptureClient = nullptr;
WAVEFORMATEX* pwfx = nullptr;
DeepgramClient* deepgramClient = nullptr;
DeepgramConfig config;
std::atomic<bool> isRunning{false};
std::thread captureThread;
public:
AudioStreamer(const DeepgramConfig& config) : config(config) {}
~AudioStreamer() {
stop();
cleanup();
}
bool initialize() {
HRESULT hr;
// Initialize COM
hr = CoInitialize(nullptr);
if (FAILED(hr)) {
std::cerr << "CoInitialize failed: " << std::hex << hr << std::endl;
return false;
}
// Create device enumerator
hr = CoCreateInstance(__uuidof(MMDeviceEnumerator), nullptr, CLSCTX_ALL,
__uuidof(IMMDeviceEnumerator), (void**)&pEnumerator);
if (FAILED(hr)) {
std::cerr << "Failed to create device enumerator: " << std::hex << hr << std::endl;
cleanup();
return false;
}
// Get default audio device
hr = pEnumerator->GetDefaultAudioEndpoint(eRender, eConsole, &pDevice);
if (FAILED(hr)) {
std::cerr << "Failed to get default audio device: " << std::hex << hr << std::endl;
cleanup();
return false;
}
// Activate audio client
hr = pDevice->Activate(__uuidof(IAudioClient), CLSCTX_ALL, nullptr, (void**)&pAudioClient);
if (FAILED(hr)) {
std::cerr << "Failed to activate audio client: " << std::hex << hr << std::endl;
cleanup();
return false;
}
// Get audio format
hr = pAudioClient->GetMixFormat(&pwfx);
if (FAILED(hr)) {
std::cerr << "Failed to get mix format: " << std::hex << hr << std::endl;
cleanup();
return false;
}
// Print audio format for debugging
std::cout << "Audio Format: " << std::endl;
std::cout << " Format Tag: " << pwfx->wFormatTag << std::endl;
std::cout << " Channels: " << pwfx->nChannels << std::endl;
std::cout << " Sample Rate: " << pwfx->nSamplesPerSec << std::endl;
std::cout << " Bits Per Sample: " << pwfx->wBitsPerSample << std::endl;
// Initialize audio client
hr = pAudioClient->Initialize(AUDCLNT_SHAREMODE_SHARED, AUDCLNT_STREAMFLAGS_LOOPBACK,
0, 0, pwfx, nullptr);
if (FAILED(hr)) {
std::cerr << "Failed to initialize audio client: " << std::hex << hr << std::endl;
cleanup();
return false;
}
// Get capture client
hr = pAudioClient->GetService(__uuidof(IAudioCaptureClient), (void**)&pCaptureClient);
if (FAILED(hr)) {
std::cerr << "Failed to get capture client: " << std::hex << hr << std::endl;
cleanup();
return false;
}
// Create and connect to Deepgram
deepgramClient = new DeepgramClient(config);
if (!deepgramClient->connect()) {
std::cerr << "Failed to connect to Deepgram" << std::endl;
cleanup();
return false;
}
return true;
}
bool start() {
if (isRunning) return true;
// Start audio capture
HRESULT hr = pAudioClient->Start();
if (FAILED(hr)) {
std::cerr << "Failed to start audio capture: " << std::hex << hr << std::endl;
return false;
}
isRunning = true;
// Start capture thread
captureThread = std::thread([this]() {
captureAndStream();
});
return true;
}
void stop() {
if (!isRunning) return;
isRunning = false;
// Wait for capture thread to complete
if (captureThread.joinable()) {
captureThread.join();
}
// Stop audio capture
if (pAudioClient) {
pAudioClient->Stop();
}
// Disconnect from Deepgram
if (deepgramClient) {
deepgramClient->disconnect();
}
}
private:
void captureAndStream() {
HRESULT hr;
UINT32 packetLength = 0;
BYTE* pData;
DWORD flags;
UINT32 numFrames;
std::cout << "Audio capture started. Streaming to Deepgram..." << std::endl;
while (isRunning) {
// Get next packet size
hr = pCaptureClient->GetNextPacketSize(&packetLength);
if (FAILED(hr)) {
std::cerr << "Failed to get next packet size: " << std::hex << hr << std::endl;
break;
}
if (packetLength == 0) {
// No data yet, wait a bit
std::this_thread::sleep_for(std::chrono::milliseconds(10));
continue;
}
// Get the captured data
hr = pCaptureClient->GetBuffer(&pData, &numFrames, &flags, nullptr, nullptr);
if (FAILED(hr)) {
std::cerr << "Failed to get buffer: " << std::hex << hr << std::endl;
break;
}
// Skip silent frames
if (!(flags & AUDCLNT_BUFFERFLAGS_SILENT)) {
// Convert audio to format expected by Deepgram (16-bit PCM, mono, 16kHz)
std::vector<int16_t> monoAudio = AudioResampler::ConvertAudio(pData, numFrames, pwfx);
// Resample if needed
if (pwfx->nSamplesPerSec != config.sampleRate) {
monoAudio = AudioResampler::Resample(monoAudio, pwfx->nSamplesPerSec, config.sampleRate);
}
// Send to Deepgram
if (!deepgramClient->sendAudio(monoAudio)) {
std::cerr << "Failed to send audio to Deepgram" << std::endl;
break;
}
}
// Release the buffer
hr = pCaptureClient->ReleaseBuffer(numFrames);
if (FAILED(hr)) {
std::cerr << "Failed to release buffer: " << std::hex << hr << std::endl;
break;
}
}
std::cout << "Audio capture stopped." << std::endl;
}
void cleanup() {
if (deepgramClient) {
delete deepgramClient;
deepgramClient = nullptr;
}
if (pCaptureClient) {
pCaptureClient->Release();
pCaptureClient = nullptr;
}
if (pAudioClient) {
pAudioClient->Release();
pAudioClient = nullptr;
}
if (pwfx) {
CoTaskMemFree(pwfx);
pwfx = nullptr;
}
if (pDevice) {
pDevice->Release();
pDevice = nullptr;
}
if (pEnumerator) {
pEnumerator->Release();
pEnumerator = nullptr;
}
CoUninitialize();
}
};
// Entry point: streams desktop (loopback) audio to Deepgram for live
// transcription until the user presses Enter. Expects the Deepgram API
// key as the sole command-line argument.
int main(int argc, char* argv[]) {
    // Guard clause: require the API key argument up front.
    if (argc < 2) {
        std::cerr << "Usage: " << argv[0] << " <Deepgram API Key>" << std::endl;
        std::cerr << "Example: " << argv[0] << " YOUR_DEEPGRAM_API_KEY" << std::endl;
        return 1;
    }
    // Build the session configuration from the command line.
    DeepgramConfig config;
    config.apiKey = argv[1];
    // Bring up capture + Deepgram connection, then start streaming.
    AudioStreamer streamer(config);
    if (!streamer.initialize()) {
        std::cerr << "Failed to initialize audio streamer" << std::endl;
        return 1;
    }
    if (!streamer.start()) {
        std::cerr << "Failed to start audio streaming" << std::endl;
        return 1;
    }
    std::cout << "Streaming desktop audio to Deepgram for transcription..." << std::endl;
    std::cout << "Press Enter to stop." << std::endl;
    // Block until the user presses Enter, then shut down cleanly.
    std::cin.get();
    streamer.stop();
    std::cout << "Audio streaming stopped." << std::endl;
    return 0;
}