ONNX Runtime generate() Java API

Note: this API is in preview and is subject to change.

Overview

This document describes the Java API for ONNX Runtime GenAI.
Each section below covers one class and lists its public methods, with a short description and the signature of each.


Install and import

The Java API is provided by the ai.onnxruntime.genai package. The package has not been published yet; to build it from source, see the build from source guide.

import ai.onnxruntime.genai.*; 

Model class

Constructor

Initializes a new model from the given model path.

public Model(String modelPath) throws GenAIException 

createGeneratorParams

Creates a GeneratorParams instance for executing the model.

public GeneratorParams createGeneratorParams() throws GenAIException 

createTokenizer

Creates a Tokenizer instance for this model.

public Tokenizer createTokenizer() throws GenAIException 

generate

Generates output sequences using the provided generator parameters.

public Sequences generate(GeneratorParams generatorParams) throws GenAIException 

Config class

Constructor

Initializes a new configuration object from a config path.

public Config(String configPath) throws GenAIException 

clearProviders

Clears all execution providers from the configuration.

public void clearProviders() throws GenAIException 

appendProvider

Appends an execution provider to the configuration.

public void appendProvider(String provider) throws GenAIException 

setProviderOption

Sets an option, given as a name/value pair, for the named execution provider in the configuration.

public void setProviderOption(String provider, String name, String value) throws GenAIException 

overlay

Overlays a JSON string onto the configuration.

public void overlay(String json) throws GenAIException 
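
As an illustration of what an overlay string might look like, the sketch below builds a partial JSON document that overrides a single search setting. It assumes the configuration follows the genai_config.json layout, with a top-level search object that contains max_length. The ConfigOverlay class and its method are hypothetical helpers, not part of the API.

```java
// Hypothetical helper, not part of ai.onnxruntime.genai.
final class ConfigOverlay {
    /**
     * Builds a partial config document that names only search.max_length.
     * Assumes the genai_config.json layout with a top-level "search" object.
     */
    static String maxLengthOverlay(int maxLength) {
        if (maxLength <= 0) {
            throw new IllegalArgumentException("maxLength must be positive");
        }
        // Plain concatenation keeps the digits locale-independent.
        return "{\"search\": {\"max_length\": " + maxLength + "}}";
    }
}
```

The resulting string could then be passed to overlay, for example config.overlay(ConfigOverlay.maxLengthOverlay(256)).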

Tokenizer class

Constructor

Initializes a tokenizer for the given model.

public Tokenizer(Model model) throws GenAIException 

encode

Encodes a string into a sequence of token ids.

public Sequences encode(String string) throws GenAIException 

encodeBatch

Encodes an array of strings, producing one sequence of token ids per input string.

public Sequences encodeBatch(String[] strings) throws GenAIException 

decode

Decodes a sequence of token ids into text.

public String decode(int[] sequence) throws GenAIException 

decodeBatch

Decodes a batch of sequences of token ids into text.

public String[] decodeBatch(Sequences sequences) throws GenAIException 

createStream

Creates a TokenizerStream object for streaming tokenization.

public TokenizerStream createStream() throws GenAIException 

TokenizerStream class

decode

Decodes a single token in the stream and returns the generated string chunk.

public String decode(int token) throws GenAIException 

GeneratorParams class

Constructor

Initializes generator parameters for the given model.

public GeneratorParams(Model model) throws GenAIException 

setSearchOption (double)

Sets a numeric search option, such as max_length or temperature.

public void setSearchOption(String optionName, double value) throws GenAIException 

setSearchOption (boolean)

Sets a boolean search option, such as do_sample or early_stopping.

public void setSearchOption(String optionName, boolean value) throws GenAIException 

setInput (Sequences)

Sets the prompt or prompts for model execution from encoded Sequences, such as those returned by Tokenizer.encode or Tokenizer.encodeBatch.

public void setInput(Sequences sequences) throws GenAIException 

setInput (int[])

Sets the prompt token ids for model execution from a flat array that holds batchSize sequences, each sequenceLength tokens long.

public void setInput(int[] tokenIds, int sequenceLength, int batchSize) throws GenAIException 
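
To make the flat-array form concrete, here is a minimal sketch of packing a batch of equal-length token sequences into a single int[]. It assumes a row-major layout (all tokens of sequence 0, then sequence 1, and so on) and that every sequence in the batch has the same length. BatchInput is a hypothetical helper, and the token ids in the usage note are arbitrary placeholders.

```java
// Hypothetical helper, not part of ai.onnxruntime.genai.
final class BatchInput {
    /**
     * Flattens equal-length token sequences into one array, row by row.
     * Assumption: the native layer expects batchSize x sequenceLength, row-major.
     */
    static int[] flatten(int[][] batch) {
        if (batch.length == 0) {
            throw new IllegalArgumentException("batch must contain at least one sequence");
        }
        int sequenceLength = batch[0].length;
        int[] flat = new int[batch.length * sequenceLength];
        for (int i = 0; i < batch.length; i++) {
            if (batch[i].length != sequenceLength) {
                throw new IllegalArgumentException(
                    "sequence " + i + " has length " + batch[i].length
                        + ", expected " + sequenceLength);
            }
            // Copy row i into its slot in the flat array.
            System.arraycopy(batch[i], 0, flat, i * sequenceLength, sequenceLength);
        }
        return flat;
    }
}
```

For a batch such as {{1, 2, 3}, {4, 5, 6}}, the flattened array would be passed as setInput(flat, 3, 2), so that flat.length equals sequenceLength * batchSize.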

Generator class

Constructor

Constructs a Generator object with the given model and generator parameters.

public Generator(Model model, GeneratorParams generatorParams) throws GenAIException 

isDone

Checks if the generation process is done.

public boolean isDone() 

computeLogits

Computes the logits for the next token in the sequence.

public void computeLogits() throws GenAIException 

generateNextToken

Generates the next token in the sequence.

public void generateNextToken() throws GenAIException 

getSequence

Retrieves a sequence of token ids for the specified sequence index.

public int[] getSequence(long sequenceIndex) throws GenAIException 

getLastTokenInSequence

Retrieves the last token in the sequence for the specified sequence index.

public int getLastTokenInSequence(long sequenceIndex) throws GenAIException 
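
The Generator methods above are typically called in a loop, one token per iteration. The sketch below shows that loop together with TokenizerStream.decode for incremental output. Since the package is not yet published, TokenSource and ChunkDecoder are hypothetical stand-in interfaces that mirror the documented signatures (with Exception in place of GenAIException) so that the code compiles on its own. With the real library, a Generator and a TokenizerStream would take their place.

```java
// Hypothetical stand-in for the documented Generator methods.
interface TokenSource {
    boolean isDone();
    void computeLogits() throws Exception;
    void generateNextToken() throws Exception;
    int getLastTokenInSequence(long sequenceIndex) throws Exception;
}

// Hypothetical stand-in for TokenizerStream.decode(int).
interface ChunkDecoder {
    String decode(int token) throws Exception;
}

final class StreamingLoop {
    /** Drives generation to completion and returns the decoded text of sequence 0. */
    static String run(TokenSource generator, ChunkDecoder stream) throws Exception {
        StringBuilder text = new StringBuilder();
        while (!generator.isDone()) {
            generator.computeLogits();      // compute logits for the next position
            generator.generateNextToken();  // select the next token and append it
            int token = generator.getLastTokenInSequence(0);
            text.append(stream.decode(token)); // decode only the newly added token
        }
        return text.toString();
    }
}
```

In an interactive application, each decoded chunk could be printed as it arrives instead of being collected in a StringBuilder.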

Sequences class

numSequences

Gets the number of sequences in the collection.

public long numSequences() 

getSequence

Gets the sequence at the specified index.

public int[] getSequence(long sequenceIndex) 

Tensor class

Constructor

Constructs a Tensor with the given data, shape, and element type.

public Tensor(ByteBuffer data, long[] shape, ElementType elementType) throws GenAIException 
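
A sketch of preparing the data argument is shown below. It assumes the native layer expects a direct buffer in native byte order whose element count matches the product of the shape dimensions; this reference does not state those requirements, so treat them as assumptions. TensorBuffers is a hypothetical helper that uses only java.nio.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Hypothetical helper, not part of ai.onnxruntime.genai.
final class TensorBuffers {
    /** Number of elements implied by a shape, for example {2, 3} gives 6. */
    static long elementCount(long[] shape) {
        long count = 1;
        for (long dim : shape) {
            if (dim < 0) {
                throw new IllegalArgumentException("negative dimension: " + dim);
            }
            count = Math.multiplyExact(count, dim); // throws on overflow
        }
        return count;
    }

    /** Copies float values into a direct, native-order buffer sized for the shape. */
    static ByteBuffer floatBuffer(float[] values, long[] shape) {
        if (elementCount(shape) != values.length) {
            throw new IllegalArgumentException("shape does not match number of values");
        }
        ByteBuffer data = ByteBuffer.allocateDirect(values.length * Float.BYTES)
                                    .order(ByteOrder.nativeOrder());
        // Writing through a float view leaves the byte buffer's own position at 0.
        data.asFloatBuffer().put(values);
        return data;
    }
}
```

The returned buffer and the same shape array would then be passed to the Tensor constructor along with the ElementType constant for 32-bit floats.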

Result class

isSuccess

Indicates if the operation was successful.

public boolean isSuccess() 

getError

Gets the error message from a failed operation.

public String getError() 

Utils class

setLogBool

Sets a boolean logging option.

public static void setLogBool(String name, boolean value) 

setLogString

Sets a string logging option.

public static void setLogString(String name, String value) 

setCurrentGpuDeviceId

Sets the current GPU device ID.

public static void setCurrentGpuDeviceId(int deviceId) 

getCurrentGpuDeviceId

Gets the current GPU device ID.

public static int getCurrentGpuDeviceId()