Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public static RoomConfig of(AudioCodec codec) {
}

@JsonProperty("room_mode")
@Builder.Default
private String roomMode = "";

@JsonProperty("translate_config")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,18 @@ public class AsrConfig {
private String context;

@JsonProperty("user_language")
private String userLanguage;
@Builder.Default
private String userLanguage = "common";

@JsonProperty("enable_ddc")
@Builder.Default
private Boolean enableDdc = true;

@JsonProperty("enable_itn")
@Builder.Default
private Boolean enableItn = true;

@JsonProperty("enable_punc")
@Builder.Default
private Boolean enablePunc = true;
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,11 @@
public class TranscriptionsUpdateEventData {
@JsonProperty("input_audio")
private InputAudio inputAudio;

@JsonProperty("asr_config")
private AsrConfig asrConfig;

public TranscriptionsUpdateEventData(InputAudio inputAudio) {
this.inputAudio = inputAudio;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;

import java.util.Arrays;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
Expand All @@ -15,6 +17,7 @@

import com.coze.openapi.client.websocket.event.EventType;
import com.coze.openapi.client.websocket.event.downstream.*;
import com.coze.openapi.client.websocket.event.model.AsrConfig;
import com.coze.openapi.client.websocket.event.model.InputAudio;
import com.coze.openapi.client.websocket.event.model.TranscriptionsUpdateEventData;

Expand Down Expand Up @@ -117,6 +120,68 @@ public void testHandleTranscriptionsUpdatedEvent() {
assertEquals(1, event.getData().getInputAudio().getChannel());
assertEquals(16, event.getData().getInputAudio().getBitDepth());

// 验证 asr_config 为 null(向后兼容)
assertNull(event.getData().getAsrConfig());

// 验证 detail
assertEquals("20241210152726467C48D89D6DB2F3***", event.getDetail().getLogID());
}

@Test
public void testHandleTranscriptionsUpdatedEventWithAsrConfig() {
String json =
"{\n"
+ " \"id\": \"event_id\",\n"
+ " \"event_type\": \"transcriptions.updated\",\n"
+ " \"data\": {\n"
+ " \"input_audio\": {\n"
+ " \"format\": \"pcm\",\n"
+ " \"codec\": \"pcm\",\n"
+ " \"sample_rate\": 24000,\n"
+ " \"channel\": 1,\n"
+ " \"bit_depth\": 16\n"
+ " },\n"
+ " \"asr_config\": {\n"
+ " \"hot_words\": [\"Coze\", \"AI\"],\n"
+ " \"context\": \"Coze AI\",\n"
+ " \"user_language\": \"en-US\",\n"
+ " \"enable_ddc\": true,\n"
+ " \"enable_itn\": true,\n"
+ " \"enable_punc\": true\n"
+ " }\n"
+ " },\n"
+ " \"detail\": {\n"
+ " \"logid\": \"20241210152726467C48D89D6DB2F3***\"\n"
+ " }\n"
+ "}\n";

client.handleEvent(mockWebSocket, json);

verify(mockCallbackHandler)
.onTranscriptionsUpdated(eq(client), transcriptionsUpdatedEventCaptor.capture());

TranscriptionsUpdatedEvent event = transcriptionsUpdatedEventCaptor.getValue();
assertEquals(EventType.TRANSCRIPTIONS_UPDATED, event.getEventType());
assertEquals("event_id", event.getId());

// 验证 data
assertEquals("pcm", event.getData().getInputAudio().getFormat());
assertEquals("pcm", event.getData().getInputAudio().getCodec());
assertEquals(24000, event.getData().getInputAudio().getSampleRate());
assertEquals(1, event.getData().getInputAudio().getChannel());
assertEquals(16, event.getData().getInputAudio().getBitDepth());

// 验证 asr_config
assertNotNull(event.getData().getAsrConfig());
assertEquals("en-US", event.getData().getAsrConfig().getUserLanguage());
assertEquals("Coze AI", event.getData().getAsrConfig().getContext());
assertTrue(event.getData().getAsrConfig().getEnableDdc());
assertTrue(event.getData().getAsrConfig().getEnableItn());
assertTrue(event.getData().getAsrConfig().getEnablePunc());
assertEquals(2, event.getData().getAsrConfig().getHotWords().size());
assertTrue(event.getData().getAsrConfig().getHotWords().contains("Coze"));
assertTrue(event.getData().getAsrConfig().getHotWords().contains("AI"));

// 验证 detail
assertEquals("20241210152726467C48D89D6DB2F3***", event.getDetail().getLogID());
}
Expand Down Expand Up @@ -287,13 +352,36 @@ void testTranscriptionsUpdate() {
.channel(1)
.bitDepth(16)
.build())
.asrConfig(
AsrConfig.builder()
.hotWords(Arrays.asList("Coze", "AI"))
.context("Real-time transcription")
.userLanguage("en-US")
.build())
.build();

client.transcriptionsUpdate(data);

verify(mockWebSocket).send(anyString()); // 验证发送了消息
}

@Test
void testTranscriptionsUpdateWithoutAsrConfig() {
TranscriptionsUpdateEventData data =
new TranscriptionsUpdateEventData(
InputAudio.builder()
.format("pcm")
.codec("pcm")
.sampleRate(24000)
.channel(1)
.bitDepth(16)
.build());

client.transcriptionsUpdate(data);

verify(mockWebSocket).send(anyString()); // 验证发送了消息
}

@Test
void testInputAudioBufferAppendWithString() {
String audioData = "base64EncodedAudioData";
Expand Down
2 changes: 2 additions & 0 deletions api/src/test/java/com/coze/openapi/utils/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ public class Utils {
private static final Headers commonHeader =
Headers.of(
new HashMap<String, String>() {
private static final long serialVersionUID = 1L;

{
put(LOG_HEADER, TEST_LOG_ID);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.concurrent.TimeUnit;

import com.coze.openapi.client.audio.common.AudioFormat;
import com.coze.openapi.client.audio.speech.CreateSpeechReq;
import com.coze.openapi.client.audio.speech.CreateSpeechResp;
import com.coze.openapi.client.websocket.event.downstream.*;
import com.coze.openapi.client.websocket.event.model.AsrConfig;
import com.coze.openapi.client.websocket.event.model.InputAudio;
import com.coze.openapi.client.websocket.event.model.TranscriptionsUpdateEventData;
import com.coze.openapi.service.auth.TokenAuth;
Expand All @@ -27,7 +27,6 @@ public class WebsocketTranscriptionsExample {
public static boolean isDone = false;

private static class CallbackHandler extends WebsocketsAudioTranscriptionsCallbackHandler {
private final ByteBuffer buffer = ByteBuffer.allocate(1024 * 1024 * 10); // 分配 10MB 缓冲区

public CallbackHandler() {
super();
Expand Down Expand Up @@ -120,7 +119,15 @@ public static void main(String[] args) throws Exception {

InputAudio inputAudio =
InputAudio.builder().sampleRate(24000).codec("pcm").format("wav").channel(2).build();
client.transcriptionsUpdate(new TranscriptionsUpdateEventData(inputAudio));

AsrConfig asrConfig =
AsrConfig.builder()
.hotWords(Arrays.asList("Coze", "AI"))
.context("Real-time transcription")
.userLanguage("en-US")
.build();

client.transcriptionsUpdate(new TranscriptionsUpdateEventData(inputAudio, asrConfig));

try (InputStream inputStream = speechResp.getResponse().byteStream()) {
byte[] buffer = new byte[1024];
Expand Down
Loading