diff --git a/sdk/cs/README.md b/sdk/cs/README.md index ad6f477a..26287217 100644 --- a/sdk/cs/README.md +++ b/sdk/cs/README.md @@ -181,11 +181,11 @@ var loaded = await catalog.GetLoadedModelsAsync(); ### Model Lifecycle -Each `Model` wraps one or more `ModelVariant` entries (different quantizations, hardware targets). The SDK auto-selects the best variant, or you can pick one: +Each model may have multiple variants (different quantizations, hardware targets). The SDK auto-selects the best variant, or you can pick one. All models implement the `IModel` interface. ```csharp // Check and select variants -Console.WriteLine($"Selected: {model.SelectedVariant.Id}"); +Console.WriteLine($"Selected: {model.Id}"); foreach (var v in model.Variants) Console.WriteLine($" {v.Id} (cached: {await v.IsCachedAsync()})"); @@ -389,8 +389,8 @@ Key types: | [`FoundryLocalManager`](./docs/api/microsoft.ai.foundry.local.foundrylocalmanager.md) | Singleton entry point — create, catalog, web service | | [`Configuration`](./docs/api/microsoft.ai.foundry.local.configuration.md) | Initialization settings | | [`ICatalog`](./docs/api/microsoft.ai.foundry.local.icatalog.md) | Model catalog interface | -| [`Model`](./docs/api/microsoft.ai.foundry.local.model.md) | Model with variant selection | -| [`ModelVariant`](./docs/api/microsoft.ai.foundry.local.modelvariant.md) | Specific model variant (hardware/quantization) | +| [`IModel`](./docs/api/microsoft.ai.foundry.local.imodel.md) | Model interface — identity, metadata, lifecycle, variant selection | +| [`Model`](./docs/api/microsoft.ai.foundry.local.model.md) | Model with variant selection (implements `IModel`) | | [`OpenAIChatClient`](./docs/api/microsoft.ai.foundry.local.openaichatclient.md) | Chat completions (sync + streaming) | | [`OpenAIAudioClient`](./docs/api/microsoft.ai.foundry.local.openaiaudioclient.md) | Audio transcription (sync + streaming) | | 
[`LiveAudioTranscriptionSession`](./docs/api/microsoft.ai.foundry.local.openai.liveaudiotranscriptionsession.md) | Real-time audio streaming session | diff --git a/sdk/js/README.md b/sdk/js/README.md index 9e56ec52..5590ab12 100644 --- a/sdk/js/README.md +++ b/sdk/js/README.md @@ -148,7 +148,7 @@ const loaded = await catalog.getLoadedModels(); ### Loading and Running Models -Each `Model` can have multiple variants (different quantizations or formats). The SDK automatically selects the best available variant, preferring cached versions. +Each model can have multiple variants (different quantizations or formats). The SDK automatically selects the best available variant, preferring cached versions. All models implement the `IModel` interface. ```typescript const model = await catalog.getModel('qwen2.5-0.5b'); @@ -259,8 +259,7 @@ Auto-generated class documentation lives in [`docs/classes/`](docs/classes/): - [FoundryLocalManager](docs/classes/FoundryLocalManager.md) — SDK entry point, web service management - [Catalog](docs/classes/Catalog.md) — Model discovery and browsing -- [Model](docs/classes/Model.md) — High-level model with variant selection -- [ModelVariant](docs/classes/ModelVariant.md) — Specific model variant: download, load, inference +- [IModel](docs/README.md#imodel) — Model interface: variant selection, download, load, inference - [ChatClient](docs/classes/ChatClient.md) — Chat completions (sync and streaming) - [AudioClient](docs/classes/AudioClient.md) — Audio transcription (sync and streaming) - [ModelLoadManager](docs/classes/ModelLoadManager.md) — Low-level model loading management diff --git a/sdk/js/docs/README.md b/sdk/js/docs/README.md index 0cb39e1b..b0167b4d 100644 --- a/sdk/js/docs/README.md +++ b/sdk/js/docs/README.md @@ -23,7 +23,6 @@ - [FoundryLocalManager](classes/FoundryLocalManager.md) - [Model](classes/Model.md) - [ModelLoadManager](classes/ModelLoadManager.md) -- [ModelVariant](classes/ModelVariant.md) - 
[ResponsesClient](classes/ResponsesClient.md) - [ResponsesClientSettings](classes/ResponsesClientSettings.md) @@ -562,6 +561,18 @@ get id(): string; `string` +##### info + +###### Get Signature + +```ts +get info(): ModelInfo; +``` + +###### Returns + +[`ModelInfo`](#modelinfo) + ##### inputModalities ###### Get Signature @@ -622,6 +633,20 @@ get supportsToolCalling(): boolean | null; `boolean` \| `null` +##### variants + +###### Get Signature + +```ts +get variants(): IModel[]; +``` + +Variants of the model that are available. Variants of the model are optimized for different devices. + +###### Returns + +[`IModel`](#imodel)[] + #### Methods ##### createAudioClient() @@ -710,6 +735,29 @@ removeFromCache(): void; `void` +##### selectVariant() + +```ts +selectVariant(variant): void; +``` + +Select a model variant from variants to use for IModel operations. +An IModel from `variants` can also be used directly. + +###### Parameters + +| Parameter | Type | Description | +| ------ | ------ | ------ | +| `variant` | [`IModel`](#imodel) | Model variant to select. Must be one of the variants in `variants`. | + +###### Returns + +`void` + +###### Throws + +Error if variant is not valid for this model. + ##### unload() ```ts diff --git a/sdk/js/docs/classes/Catalog.md b/sdk/js/docs/classes/Catalog.md index 23f7cff3..78ce821c 100644 --- a/sdk/js/docs/classes/Catalog.md +++ b/sdk/js/docs/classes/Catalog.md @@ -47,7 +47,7 @@ The name of the catalog. ### getCachedModels() ```ts -getCachedModels(): Promise<ModelVariant[]>; +getCachedModels(): Promise<IModel[]>; ``` Retrieves a list of all locally cached model variants. @@ -55,16 +55,39 @@ This method is asynchronous as it may involve file I/O or querying the underlyin #### Returns -`Promise`\<[`ModelVariant`](ModelVariant.md)[]\> +`Promise`\<[`IModel`](../README.md#imodel)[]\> -A Promise that resolves to an array of cached ModelVariant objects. +A Promise that resolves to an array of cached IModel objects. 
+ +*** + +### getLatestVersion() + +```ts +getLatestVersion(modelOrModelVariant): Promise<IModel>; +``` + +Get the latest version of a model. +This is used to check if a newer version of a model is available in the catalog for download. + +#### Parameters + +| Parameter | Type | Description | +| ------ | ------ | ------ | +| `modelOrModelVariant` | [`IModel`](../README.md#imodel) | The model to check for the latest version. | + +#### Returns + +`Promise`\<[`IModel`](../README.md#imodel)\> + +The latest version of the model. Will match the input if it is the latest version. *** ### getLoadedModels() ```ts -getLoadedModels(): Promise<ModelVariant[]>; +getLoadedModels(): Promise<IModel[]>; ``` Retrieves a list of all currently loaded model variants. @@ -73,16 +96,16 @@ the underlying core or an external service, which can be an I/O bound operation. #### Returns -`Promise`\<[`ModelVariant`](ModelVariant.md)[]\> +`Promise`\<[`IModel`](../README.md#imodel)[]\> -A Promise that resolves to an array of loaded ModelVariant objects. +A Promise that resolves to an array of loaded IModel objects. *** ### getModel() ```ts -getModel(alias): Promise<Model>; +getModel(alias): Promise<IModel>; ``` Retrieves a model by its alias. @@ -96,9 +119,9 @@ This method is asynchronous as it may ensure the catalog is up-to-date by fetchi #### Returns -`Promise`\<[`Model`](Model.md)\> +`Promise`\<[`IModel`](../README.md#imodel)\> -A Promise that resolves to the Model object if found, otherwise throws an error. +A Promise that resolves to the IModel object if found, otherwise throws an error. #### Throws @@ -109,7 +132,7 @@ Error - If alias is null, undefined, or empty. ### getModels() ```ts -getModels(): Promise<Model[]>; +getModels(): Promise<IModel[]>; ``` Lists all available models in the catalog. @@ -117,19 +140,21 @@ This method is asynchronous as it may fetch the model list from a remote service #### Returns -`Promise`\<[`Model`](Model.md)[]\> +`Promise`\<[`IModel`](../README.md#imodel)[]\> -A Promise that resolves to an array of Model objects. 
+A Promise that resolves to an array of IModel objects. *** ### getModelVariant() ```ts -getModelVariant(modelId): Promise<ModelVariant>; +getModelVariant(modelId): Promise<IModel>; ``` Retrieves a specific model variant by its ID. +NOTE: This will return an IModel with a single variant. Use getModel to get an IModel with all available +variants. This method is asynchronous as it may ensure the catalog is up-to-date by fetching from a remote service. #### Parameters @@ -140,9 +165,9 @@ This method is asynchronous as it may ensure the catalog is up-to-date by fetchi #### Returns -`Promise`\<[`ModelVariant`](ModelVariant.md)\> +`Promise`\<[`IModel`](../README.md#imodel)\> -A Promise that resolves to the ModelVariant object if found, otherwise throws an error. +A Promise that resolves to the IModel object if found, otherwise throws an error. #### Throws diff --git a/sdk/js/docs/classes/Model.md b/sdk/js/docs/classes/Model.md index 0b2dcfa6..f678f873 100644 --- a/sdk/js/docs/classes/Model.md +++ b/sdk/js/docs/classes/Model.md @@ -21,7 +21,7 @@ new Model(variant): Model; | Parameter | Type | | ------ | ------ | -| `variant` | [`ModelVariant`](ModelVariant.md) | +| `variant` | `ModelVariant` | #### Returns @@ -109,6 +109,28 @@ The ID of the selected variant. *** +### info + +#### Get Signature + +```ts +get info(): ModelInfo; +``` + +Gets the ModelInfo of the currently selected variant. + +##### Returns + +[`ModelInfo`](../README.md#modelinfo) + +The ModelInfo object. + +#### Implementation of + +[`IModel`](../README.md#imodel).[`info`](../README.md#info) + +*** + ### inputModalities #### Get Signature @@ -212,43 +234,22 @@ get supportsToolCalling(): boolean | null; #### Get Signature ```ts -get variants(): ModelVariant[]; +get variants(): IModel[]; ``` Gets all available variants for this model. ##### Returns -[`ModelVariant`](ModelVariant.md)[] - -An array of ModelVariant objects. +[`IModel`](../README.md#imodel)[] -## Methods +An array of IModel objects. 
-### addVariant() - -```ts -addVariant(variant): void; -``` - -Adds a new variant to this model. -Automatically selects the new variant if it is cached and the current one is not. - -#### Parameters - -| Parameter | Type | Description | -| ------ | ------ | ------ | -| `variant` | [`ModelVariant`](ModelVariant.md) | The model variant to add. | - -#### Returns - -`void` - -#### Throws +#### Implementation of -Error - If the argument is not a ModelVariant object, or if the variant's alias does not match the model's alias. +[`IModel`](../README.md#imodel).[`variants`](../README.md#variants) -*** +## Methods ### createAudioClient() @@ -410,7 +411,7 @@ Selects a specific variant. | Parameter | Type | Description | | ------ | ------ | ------ | -| `variant` | [`ModelVariant`](ModelVariant.md) | The model variant to select. | +| `variant` | [`IModel`](../README.md#imodel) | The model variant to select. Must be one of the variants in `variants`. | #### Returns @@ -418,7 +419,11 @@ Selects a specific variant. #### Throws -Error - If the argument is not a ModelVariant object, or if the variant does not belong to this model. +Error - If the variant does not belong to this model. + +#### Implementation of + +[`IModel`](../README.md#imodel).[`selectVariant`](../README.md#selectvariant) *** diff --git a/sdk/js/docs/classes/ModelVariant.md b/sdk/js/docs/classes/ModelVariant.md deleted file mode 100644 index 6f4e5ee8..00000000 --- a/sdk/js/docs/classes/ModelVariant.md +++ /dev/null @@ -1,397 +0,0 @@ -[foundry-local-sdk](../README.md) / ModelVariant - -# Class: ModelVariant - -Represents a specific variant of a model (e.g., a specific quantization or format). -Contains the low-level implementation for interacting with the model. 
- -## Implements - -- [`IModel`](../README.md#imodel) - -## Constructors - -### Constructor - -```ts -new ModelVariant( - modelInfo, - coreInterop, - modelLoadManager): ModelVariant; -``` - -#### Parameters - -| Parameter | Type | -| ------ | ------ | -| `modelInfo` | [`ModelInfo`](../README.md#modelinfo) | -| `coreInterop` | `CoreInterop` | -| `modelLoadManager` | [`ModelLoadManager`](ModelLoadManager.md) | - -#### Returns - -`ModelVariant` - -## Accessors - -### alias - -#### Get Signature - -```ts -get alias(): string; -``` - -Gets the alias of the model. - -##### Returns - -`string` - -The model alias. - -#### Implementation of - -[`IModel`](../README.md#imodel).[`alias`](../README.md#alias) - -*** - -### capabilities - -#### Get Signature - -```ts -get capabilities(): string | null; -``` - -##### Returns - -`string` \| `null` - -#### Implementation of - -[`IModel`](../README.md#imodel).[`capabilities`](../README.md#capabilities) - -*** - -### contextLength - -#### Get Signature - -```ts -get contextLength(): number | null; -``` - -##### Returns - -`number` \| `null` - -#### Implementation of - -[`IModel`](../README.md#imodel).[`contextLength`](../README.md#contextlength) - -*** - -### id - -#### Get Signature - -```ts -get id(): string; -``` - -Gets the unique identifier of the model variant. - -##### Returns - -`string` - -The model ID. - -#### Implementation of - -[`IModel`](../README.md#imodel).[`id`](../README.md#id-3) - -*** - -### inputModalities - -#### Get Signature - -```ts -get inputModalities(): string | null; -``` - -##### Returns - -`string` \| `null` - -#### Implementation of - -[`IModel`](../README.md#imodel).[`inputModalities`](../README.md#inputmodalities) - -*** - -### isCached - -#### Get Signature - -```ts -get isCached(): boolean; -``` - -Checks if the model variant is cached locally. - -##### Returns - -`boolean` - -True if cached, false otherwise. 
- -#### Implementation of - -[`IModel`](../README.md#imodel).[`isCached`](../README.md#iscached) - -*** - -### modelInfo - -#### Get Signature - -```ts -get modelInfo(): ModelInfo; -``` - -Gets the detailed information about the model variant. - -##### Returns - -[`ModelInfo`](../README.md#modelinfo) - -The ModelInfo object. - -*** - -### outputModalities - -#### Get Signature - -```ts -get outputModalities(): string | null; -``` - -##### Returns - -`string` \| `null` - -#### Implementation of - -[`IModel`](../README.md#imodel).[`outputModalities`](../README.md#outputmodalities) - -*** - -### path - -#### Get Signature - -```ts -get path(): string; -``` - -Gets the local file path of the model variant. - -##### Returns - -`string` - -The local file path. - -#### Implementation of - -[`IModel`](../README.md#imodel).[`path`](../README.md#path) - -*** - -### supportsToolCalling - -#### Get Signature - -```ts -get supportsToolCalling(): boolean | null; -``` - -##### Returns - -`boolean` \| `null` - -#### Implementation of - -[`IModel`](../README.md#imodel).[`supportsToolCalling`](../README.md#supportstoolcalling) - -## Methods - -### createAudioClient() - -```ts -createAudioClient(): AudioClient; -``` - -Creates an AudioClient for interacting with the model via audio operations. - -#### Returns - -[`AudioClient`](AudioClient.md) - -An AudioClient instance. - -#### Implementation of - -[`IModel`](../README.md#imodel).[`createAudioClient`](../README.md#createaudioclient) - -*** - -### createChatClient() - -```ts -createChatClient(): ChatClient; -``` - -Creates a ChatClient for interacting with the model via chat completions. - -#### Returns - -[`ChatClient`](ChatClient.md) - -A ChatClient instance. 
- -#### Implementation of - -[`IModel`](../README.md#imodel).[`createChatClient`](../README.md#createchatclient) - -*** - -### createResponsesClient() - -```ts -createResponsesClient(baseUrl): ResponsesClient; -``` - -Creates a ResponsesClient for interacting with the model via the Responses API. - -#### Parameters - -| Parameter | Type | Description | -| ------ | ------ | ------ | -| `baseUrl` | `string` | The base URL of the Foundry Local web service. | - -#### Returns - -[`ResponsesClient`](ResponsesClient.md) - -A ResponsesClient instance. - -#### Implementation of - -[`IModel`](../README.md#imodel).[`createResponsesClient`](../README.md#createresponsesclient) - -*** - -### download() - -```ts -download(progressCallback?): Promise<void>; -``` - -Downloads the model variant. - -#### Parameters - -| Parameter | Type | Description | -| ------ | ------ | ------ | -| `progressCallback?` | (`progress`) => `void` | Optional callback to report download progress (0-100). | - -#### Returns - -`Promise`\<`void`\> - -#### Implementation of - -[`IModel`](../README.md#imodel).[`download`](../README.md#download) - -*** - -### isLoaded() - -```ts -isLoaded(): Promise<boolean>; -``` - -Checks if the model variant is loaded in memory. - -#### Returns - -`Promise`\<`boolean`\> - -True if loaded, false otherwise. - -#### Implementation of - -[`IModel`](../README.md#imodel).[`isLoaded`](../README.md#isloaded) - -*** - -### load() - -```ts -load(): Promise<void>; -``` - -Loads the model variant into memory. - -#### Returns - -`Promise`\<`void`\> - -A promise that resolves when the model is loaded. - -#### Implementation of - -[`IModel`](../README.md#imodel).[`load`](../README.md#load) - -*** - -### removeFromCache() - -```ts -removeFromCache(): void; -``` - -Removes the model variant from the local cache. 
- -#### Returns - -`void` - -#### Implementation of - -[`IModel`](../README.md#imodel).[`removeFromCache`](../README.md#removefromcache) - -*** - -### unload() - -```ts -unload(): Promise<void>; -``` - -Unloads the model variant from memory. - -#### Returns - -`Promise`\<`void`\> - -A promise that resolves when the model is unloaded. - -#### Implementation of - -[`IModel`](../README.md#imodel).[`unload`](../README.md#unload) diff --git a/sdk/js/src/catalog.ts b/sdk/js/src/catalog.ts index 2efba66a..d4331c38 100644 --- a/sdk/js/src/catalog.ts +++ b/sdk/js/src/catalog.ts @@ -1,8 +1,9 @@ import { CoreInterop } from './detail/coreInterop.js'; import { ModelLoadManager } from './detail/modelLoadManager.js'; -import { Model } from './model.js'; -import { ModelVariant } from './modelVariant.js'; +import { Model } from './detail/model.js'; +import { ModelVariant } from './detail/modelVariant.js'; import { ModelInfo } from './types.js'; +import { IModel } from './imodel.js'; /** * Represents a catalog of AI models available in the system. @@ -76,9 +77,9 @@ export class Catalog { /** * Lists all available models in the catalog. * This method is asynchronous as it may fetch the model list from a remote service or perform file I/O. - * @returns A Promise that resolves to an array of Model objects. + * @returns A Promise that resolves to an array of IModel objects. */ - public async getModels(): Promise<Model[]> { + public async getModels(): Promise<IModel[]> { await this.updateModels(); return this._models; } @@ -87,10 +88,10 @@ export class Catalog { * Retrieves a model by its alias. * This method is asynchronous as it may ensure the catalog is up-to-date by fetching from a remote service. * @param alias - The alias of the model to retrieve. - * @returns A Promise that resolves to the Model object if found, otherwise throws an error. + * @returns A Promise that resolves to the IModel object if found, otherwise throws an error. * @throws Error - If alias is null, undefined, or empty. 
*/ - public async getModel(alias: string): Promise<Model> { + public async getModel(alias: string): Promise<IModel> { if (typeof alias !== 'string' || alias.trim() === '') { throw new Error('Model alias must be a non-empty string.'); } @@ -105,12 +106,14 @@ export class Catalog { /** * Retrieves a specific model variant by its ID. + * NOTE: This will return an IModel with a single variant. Use getModel to get an IModel with all available + * variants. * This method is asynchronous as it may ensure the catalog is up-to-date by fetching from a remote service. * @param modelId - The unique identifier of the model variant. - * @returns A Promise that resolves to the ModelVariant object if found, otherwise throws an error. + * @returns A Promise that resolves to the IModel object if found, otherwise throws an error. * @throws Error - If modelId is null, undefined, or empty. */ - public async getModelVariant(modelId: string): Promise<ModelVariant> { + public async getModelVariant(modelId: string): Promise<IModel> { if (typeof modelId !== 'string' || modelId.trim() === '') { throw new Error('Model ID must be a non-empty string.'); } @@ -126,9 +129,9 @@ export class Catalog { /** * Retrieves a list of all locally cached model variants. * This method is asynchronous as it may involve file I/O or querying the underlying core. - * @returns A Promise that resolves to an array of cached ModelVariant objects. + * @returns A Promise that resolves to an array of cached IModel objects. 
*/ - public async getCachedModels(): Promise<ModelVariant[]> { + public async getCachedModels(): Promise<IModel[]> { await this.updateModels(); const cachedModelListJson = this.coreInterop.executeCommand("get_cached_models"); let cachedModelIds: string[] = []; @@ -137,7 +140,7 @@ export class Catalog { } catch (error) { throw new Error(`Failed to parse cached model list JSON: ${error}`); } - const cachedModels: Set<ModelVariant> = new Set(); + const cachedModels: Set<IModel> = new Set(); for (const modelId of cachedModelIds) { const variant = this.modelIdToModelVariant.get(modelId); @@ -152,9 +155,9 @@ export class Catalog { * Retrieves a list of all currently loaded model variants. * This operation is asynchronous because checking the loaded status may involve querying * the underlying core or an external service, which can be an I/O bound operation. - * @returns A Promise that resolves to an array of loaded ModelVariant objects. + * @returns A Promise that resolves to an array of loaded IModel objects. */ - public async getLoadedModels(): Promise<ModelVariant[]> { + public async getLoadedModels(): Promise<IModel[]> { await this.updateModels(); let loadedModelIds: string[] = []; try { @@ -162,7 +165,7 @@ export class Catalog { } catch (error) { throw new Error(`Failed to list loaded models: ${error}`); } - const loadedModels: ModelVariant[] = []; + const loadedModels: IModel[] = []; for (const modelId of loadedModelIds) { const variant = this.modelIdToModelVariant.get(modelId); @@ -172,4 +175,33 @@ export class Catalog { } return loadedModels; } + + /** + * Get the latest version of a model. + * This is used to check if a newer version of a model is available in the catalog for download. + * @param modelOrModelVariant - The model to check for the latest version. + * @returns The latest version of the model. Will match the input if it is the latest version. 
+ */ + public async getLatestVersion(modelOrModelVariant: IModel): Promise<IModel> { + await this.updateModels(); + + // Resolve to the parent Model by alias + const model = this.modelAliasToModel.get(modelOrModelVariant.alias); + if (!model) { + throw new Error(`Model with alias '${modelOrModelVariant.alias}' not found in catalog.`); + } + + // variants are sorted by version, so the first one matching the name is the latest version + const latest = model.variants.find(v => v.info.name === modelOrModelVariant.info.name); + if (!latest) { + throw new Error( + `Internal error. Mismatch between model (alias:${model.alias}) and ` + + `model variant (alias:${modelOrModelVariant.alias}).` + ); + } + + // if input was the latest return the input (could be model or model variant) + // otherwise return the latest model variant + return latest.id === modelOrModelVariant.id ? modelOrModelVariant : latest; + } } \ No newline at end of file diff --git a/sdk/js/src/model.ts b/sdk/js/src/detail/model.ts similarity index 78% rename from sdk/js/src/model.ts rename to sdk/js/src/detail/model.ts index b4f60040..46245ee5 100644 --- a/sdk/js/src/model.ts +++ b/sdk/js/src/detail/model.ts @@ -1,9 +1,10 @@ import { ModelVariant } from './modelVariant.js'; -import { ChatClient } from './openai/chatClient.js'; -import { AudioClient } from './openai/audioClient.js'; -import { LiveAudioTranscriptionSession } from './openai/liveAudioTranscriptionClient.js'; -import { ResponsesClient } from './openai/responsesClient.js'; -import { IModel } from './imodel.js'; +import { ChatClient } from '../openai/chatClient.js'; +import { AudioClient } from '../openai/audioClient.js'; +import { ResponsesClient } from '../openai/responsesClient.js'; +import { LiveAudioTranscriptionSession } from '../openai/liveAudioTranscriptionClient.js'; +import { IModel } from '../imodel.js'; +import { ModelInfo } from '../types.js'; /** * Represents a high-level AI model that may have multiple variants (e.g., quantized versions, 
different formats). @@ -21,25 +22,14 @@ export class Model implements IModel { this.selectedVariant = variant; } - private validateVariantInput(variant: ModelVariant, caller: string): void { - if (variant === null || variant === undefined) { - throw new Error(`${caller}() requires a ModelVariant object but received ${variant}.`); - } - if (typeof variant !== 'object') { - throw new Error( - `${caller}() requires a ModelVariant object but received ${typeof variant}.` - ); - } - } - /** * Adds a new variant to this model. * Automatically selects the new variant if it is cached and the current one is not. * @param variant - The model variant to add. - * @throws Error - If the argument is not a ModelVariant object, or if the variant's alias does not match the model's alias. + * @throws Error - If the variant's alias does not match the model's alias. + * @internal */ public addVariant(variant: ModelVariant): void { - this.validateVariantInput(variant, 'addVariant'); if (!variant || variant.alias !== this._alias) { throw new Error(`Variant alias "${variant?.alias}" does not match model alias "${this._alias}".`); } @@ -53,14 +43,13 @@ export class Model implements IModel { /** * Selects a specific variant. - * @param variant - The model variant to select. - * @throws Error - If the argument is not a ModelVariant object, or if the variant does not belong to this model. + * @param variant - The model variant to select. Must be one of the variants in `variants`. + * @throws Error - If the variant does not belong to this model. 
*/ - public selectVariant(variant: ModelVariant): void { - this.validateVariantInput(variant, 'selectVariant'); + public selectVariant(variant: IModel): void { const matchingVariant = this._variants.find(v => v.id === variant.id); if (!variant.id || !matchingVariant) { - throw new Error(`Model variant with ID ${variant.id} does not belong to model "${this._alias}".`); + throw new Error(`Input variant was not found in Variants.`); } this.selectedVariant = matchingVariant; } @@ -81,6 +70,14 @@ export class Model implements IModel { return this._alias; } + /** + * Gets the ModelInfo of the currently selected variant. + * @returns The ModelInfo object. + */ + public get info(): ModelInfo { + return this.selectedVariant.info; + } + /** * Checks if the currently selected variant is cached locally. * @returns True if cached, false otherwise. @@ -99,9 +96,9 @@ export class Model implements IModel { /** * Gets all available variants for this model. - * @returns An array of ModelVariant objects. + * @returns An array of IModel objects. 
*/ - public get variants(): ModelVariant[] { + public get variants(): IModel[] { return this._variants; } diff --git a/sdk/js/src/modelVariant.ts b/sdk/js/src/detail/modelVariant.ts similarity index 82% rename from sdk/js/src/modelVariant.ts rename to sdk/js/src/detail/modelVariant.ts index 86c3d3f5..d1c1e20c 100644 --- a/sdk/js/src/modelVariant.ts +++ b/sdk/js/src/detail/modelVariant.ts @@ -1,15 +1,16 @@ -import { CoreInterop } from './detail/coreInterop.js'; -import { ModelLoadManager } from './detail/modelLoadManager.js'; -import { ModelInfo } from './types.js'; -import { ChatClient } from './openai/chatClient.js'; -import { AudioClient } from './openai/audioClient.js'; -import { LiveAudioTranscriptionSession } from './openai/liveAudioTranscriptionClient.js'; -import { ResponsesClient } from './openai/responsesClient.js'; -import { IModel } from './imodel.js'; +import { CoreInterop } from './coreInterop.js'; +import { ModelLoadManager } from './modelLoadManager.js'; +import { ModelInfo } from '../types.js'; +import { ChatClient } from '../openai/chatClient.js'; +import { AudioClient } from '../openai/audioClient.js'; +import { LiveAudioTranscriptionSession } from '../openai/liveAudioTranscriptionClient.js'; +import { ResponsesClient } from '../openai/responsesClient.js'; +import { IModel } from '../imodel.js'; /** * Represents a specific variant of a model (e.g., a specific quantization or format). * Contains the low-level implementation for interacting with the model. + * @internal */ export class ModelVariant implements IModel { private _modelInfo: ModelInfo; @@ -42,10 +43,29 @@ export class ModelVariant implements IModel { * Gets the detailed information about the model variant. * @returns The ModelInfo object. */ - public get modelInfo(): ModelInfo { + public get info(): ModelInfo { return this._modelInfo; } + /** + * A ModelVariant is a single variant, so variants returns itself. 
+ */ + public get variants(): IModel[] { + return [this]; + } + + /** + * SelectVariant is not supported on a ModelVariant. + * Call Catalog.getModel() to get an IModel with all variants available. + * @throws Error always. + */ + public selectVariant(_variant: IModel): void { + throw new Error( + `selectVariant is not supported on a ModelVariant. ` + + `Call Catalog.getModel("${this.alias}") to get an IModel with all variants available.` + ); + } + public get contextLength(): number | null { return this._modelInfo.contextLength ?? null; } diff --git a/sdk/js/src/imodel.ts b/sdk/js/src/imodel.ts index 625afdec..7a2f5a2c 100644 --- a/sdk/js/src/imodel.ts +++ b/sdk/js/src/imodel.ts @@ -2,10 +2,12 @@ import { ChatClient } from './openai/chatClient.js'; import { AudioClient } from './openai/audioClient.js'; import { LiveAudioTranscriptionSession } from './openai/liveAudioTranscriptionClient.js'; import { ResponsesClient } from './openai/responsesClient.js'; +import { ModelInfo } from './types.js'; export interface IModel { get id(): string; get alias(): string; + get info(): ModelInfo; get isCached(): boolean; isLoaded(): Promise<boolean>; @@ -37,4 +39,17 @@ export interface IModel { * @param baseUrl - The base URL of the Foundry Local web service. */ createResponsesClient(baseUrl: string): ResponsesClient; + + /** + * Variants of the model that are available. Variants of the model are optimized for different devices. + */ + get variants(): IModel[]; + + /** + * Select a model variant from variants to use for IModel operations. + * An IModel from `variants` can also be used directly. + * @param variant - Model variant to select. Must be one of the variants in `variants`. + * @throws Error if variant is not valid for this model. 
+ */ + selectVariant(variant: IModel): void; } diff --git a/sdk/js/src/index.ts b/sdk/js/src/index.ts index 57d9fcf7..42b498c3 100644 --- a/sdk/js/src/index.ts +++ b/sdk/js/src/index.ts @@ -1,8 +1,10 @@ export { FoundryLocalManager } from './foundryLocalManager.js'; export type { FoundryLocalConfig } from './configuration.js'; export { Catalog } from './catalog.js'; -export { Model } from './model.js'; -export { ModelVariant } from './modelVariant.js'; +/** @internal */ +export { Model } from './detail/model.js'; +/** @internal */ +export { ModelVariant } from './detail/modelVariant.js'; export type { IModel } from './imodel.js'; export { ChatClient, ChatClientSettings } from './openai/chatClient.js'; export { AudioClient, AudioClientSettings } from './openai/audioClient.js'; diff --git a/sdk/js/test/catalog.test.ts b/sdk/js/test/catalog.test.ts index df47d4f6..8c320723 100644 --- a/sdk/js/test/catalog.test.ts +++ b/sdk/js/test/catalog.test.ts @@ -1,5 +1,7 @@ import { describe, it } from 'mocha'; import { expect } from 'chai'; +import { Catalog } from '../src/catalog.js'; +import { DeviceType, type ModelInfo } from '../src/types.js'; import { getTestManager, TEST_MODEL_ALIAS } from './testUtils.js'; describe('Catalog Tests', () => { @@ -106,4 +108,97 @@ describe('Catalog Tests', () => { expect((error as Error).message).to.include('Available variants:'); } }); + + it('should resolve latest version for model and variant inputs', async function() { + // Mirror the C# test by using synthetic model data sorted by version descending. 
+ const testModelInfos: ModelInfo[] = [ + { + id: 'test-model:3', + name: 'test-model', + version: 3, + alias: 'test-alias', + displayName: 'Test Model', + providerType: 'test', + uri: 'test://model/3', + modelType: 'ONNX', + runtime: { deviceType: DeviceType.CPU, executionProvider: 'CPUExecutionProvider' }, + cached: false, + createdAtUnix: 1700000003 + }, + { + id: 'test-model:2', + name: 'test-model', + version: 2, + alias: 'test-alias', + displayName: 'Test Model', + providerType: 'test', + uri: 'test://model/2', + modelType: 'ONNX', + runtime: { deviceType: DeviceType.CPU, executionProvider: 'CPUExecutionProvider' }, + cached: false, + createdAtUnix: 1700000002 + }, + { + id: 'test-model:1', + name: 'test-model', + version: 1, + alias: 'test-alias', + displayName: 'Test Model', + providerType: 'test', + uri: 'test://model/1', + modelType: 'ONNX', + runtime: { deviceType: DeviceType.CPU, executionProvider: 'CPUExecutionProvider' }, + cached: false, + createdAtUnix: 1700000001 + } + ]; + + const mockCoreInterop = { + executeCommand(command: string): string { + if (command === 'get_catalog_name') { + return 'TestCatalog'; + } + if (command === 'get_model_list') { + return JSON.stringify(testModelInfos); + } + if (command === 'get_cached_models') { + return '[]'; + } + throw new Error(`Unexpected command: ${command}`); + } + } as any; + + const mockLoadManager = { + listLoaded: async () => [] + } as any; + + const catalog = new Catalog(mockCoreInterop, mockLoadManager); + + const model = await catalog.getModel('test-alias'); + expect(model).to.not.be.undefined; + + const variants = model.variants; + expect(variants).to.have.length(3); + + const latestVariant = variants[0]; + const middleVariant = variants[1]; + const oldestVariant = variants[2]; + + expect(latestVariant.id).to.equal('test-model:3'); + expect(middleVariant.id).to.equal('test-model:2'); + expect(oldestVariant.id).to.equal('test-model:1'); + + const result1 = await 
catalog.getLatestVersion(latestVariant); + expect(result1.id).to.equal('test-model:3'); + + const result2 = await catalog.getLatestVersion(middleVariant); + expect(result2.id).to.equal('test-model:3'); + + const result3 = await catalog.getLatestVersion(oldestVariant); + expect(result3.id).to.equal('test-model:3'); + + model.selectVariant(latestVariant); + const resultFromModel = await catalog.getLatestVersion(model); + expect(resultFromModel).to.equal(model); + }); }); diff --git a/sdk/js/test/model.test.ts b/sdk/js/test/model.test.ts index acc4d6e2..4048d9a1 100644 --- a/sdk/js/test/model.test.ts +++ b/sdk/js/test/model.test.ts @@ -39,7 +39,12 @@ describe('Model Tests', () => { expect(model).to.not.be.undefined; if (!model || !cachedVariant) return; - model.selectVariant(cachedVariant); + // Select the cached variant by finding it in the model's variants + const matchingVariant = model.variants.find(v => v.id === cachedVariant.id); + expect(matchingVariant).to.not.be.undefined; + if (matchingVariant) { + model.selectVariant(matchingVariant); + } // Ensure it's not loaded initially (or unload if it is) if (await model.isLoaded()) { diff --git a/sdk/js/test/openai/responsesClient.test.ts b/sdk/js/test/openai/responsesClient.test.ts index 925a2360..f0dbf4b0 100644 --- a/sdk/js/test/openai/responsesClient.test.ts +++ b/sdk/js/test/openai/responsesClient.test.ts @@ -10,7 +10,7 @@ import type { MessageItem, } from '../../src/types.js'; import { FoundryLocalManager } from '../../src/foundryLocalManager.js'; -import { Model } from '../../src/model.js'; +import type { IModel } from '../../src/imodel.js'; describe('ResponsesClient Tests', () => { @@ -371,7 +371,7 @@ describe('ResponsesClient Tests', () => { describe('Integration (requires model + web service)', function() { let manager: FoundryLocalManager; - let model: Model; + let model: IModel; let client: ResponsesClient; let skipped = false; diff --git a/sdk/python/README.md b/sdk/python/README.md index 
4c1fb84a..4ee1f9cc 100644 --- a/sdk/python/README.md +++ b/sdk/python/README.md @@ -184,7 +184,7 @@ loaded = catalog.get_loaded_models() ### Inspecting Model Metadata -`Model` exposes metadata properties from the catalog: +`IModel` exposes metadata properties from the catalog: ```python model = catalog.get_model("phi-3.5-mini") @@ -268,8 +268,7 @@ manager.stop_web_service() | `EpInfo` | Discoverable execution provider info (`name`, `is_registered`) | | `EpDownloadResult` | Result of EP download/registration (`success`, `status`, `registered_eps`, `failed_eps`) | | `Catalog` | Model discovery – listing, lookup by alias/ID, cached/loaded queries | -| `Model` | Groups variants under one alias – select, load, unload, create clients | -| `ModelVariant` | Specific model variant – download, cache, load/unload, create clients | +| `IModel` | Abstract interface for models — identity, metadata, lifecycle, client creation, variant selection | ### OpenAI Clients @@ -282,6 +281,8 @@ manager.stop_web_service() | Class | Description | |---|---| +| `Model` | Alias-level `IModel` implementation used by `Catalog.get_model()` (implementation detail) | +| `ModelVariant` | Specific model variant (implementation detail — implements `IModel`) | | `CoreInterop` | ctypes FFI layer to the native Foundry Local Core library | | `ModelLoadManager` | Load/unload via core interop or external web service | | `ModelInfo` | Pydantic model for catalog entries | diff --git a/sdk/python/src/catalog.py b/sdk/python/src/catalog.py index afccd85b..51f5bd8f 100644 --- a/sdk/python/src/catalog.py +++ b/sdk/python/src/catalog.py @@ -11,8 +11,9 @@ from typing import List, Optional from pydantic import TypeAdapter -from .model import Model -from .model_variant import ModelVariant +from .imodel import IModel +from .detail.model import Model +from .detail.model_variant import ModelVariant from .detail.core_interop import CoreInterop, get_cached_model_ids from .detail.model_data_types import ModelInfo @@ -87,42 
+88,72 @@ def _invalidate_cache(self): with self._lock: self._last_fetch = datetime.datetime.min - def list_models(self) -> List[Model]: + def list_models(self) -> List[IModel]: """ List the available models in the catalog. - :return: List of Model instances. + :return: List of IModel instances. """ self._update_models() return list(self._model_alias_to_model.values()) - def get_model(self, model_alias: str) -> Optional[Model]: + def get_model(self, model_alias: str) -> Optional[IModel]: """ Lookup a model by its alias. :param model_alias: Model alias. - :return: Model if found. + :return: IModel if found. """ self._update_models() return self._model_alias_to_model.get(model_alias) - def get_model_variant(self, model_id: str) -> Optional[ModelVariant]: + def get_model_variant(self, model_id: str) -> Optional[IModel]: """ Lookup a model variant by its unique model id. + NOTE: This will return an IModel with a single variant. Use get_model to get an IModel with all available + variants. :param model_id: Model id. - :return: Model variant if found. + :return: IModel if found. """ self._update_models() return self._model_id_to_model_variant.get(model_id) - def get_cached_models(self) -> List[ModelVariant]: + def get_latest_version(self, model_or_model_variant: IModel) -> IModel: + """ + Resolve the latest catalog version for the provided model or variant. + + :param model_or_model_variant: IModel to resolve. + :return: Latest catalog version for the same model name. + :raises FoundryLocalException: If the alias or name cannot be resolved. + """ + self._update_models() + + model = self._model_alias_to_model.get(model_or_model_variant.alias) + if model is None: + raise FoundryLocalException( + f"Model with alias '{model_or_model_variant.alias}' not found in catalog." + ) + + latest = next( + (variant for variant in model.variants if variant.info.name == model_or_model_variant.info.name), + None, + ) + if latest is None: + raise FoundryLocalException( + f"Internal error. 
Mismatch between model (alias:{model.alias}) and " + f"model variant (alias:{model_or_model_variant.alias})." + ) + + return model_or_model_variant if latest.id == model_or_model_variant.id else latest + + def get_cached_models(self) -> List[IModel]: """ Get a list of currently downloaded models from the model cache. - :return: List of ModelVariant instances. + :return: List of IModel instances. """ self._update_models() cached_model_ids = get_cached_model_ids(self._core_interop) - cached_models = [] + cached_models: List[IModel] = [] for model_id in cached_model_ids: model_variant = self._model_id_to_model_variant.get(model_id) if model_variant is not None: @@ -130,15 +161,15 @@ def get_cached_models(self) -> List[ModelVariant]: return cached_models - def get_loaded_models(self) -> List[ModelVariant]: + def get_loaded_models(self) -> List[IModel]: """ Get a list of the currently loaded models. - :return: List of ModelVariant instances. + :return: List of IModel instances. """ self._update_models() loaded_model_ids = self._model_load_manager.list_loaded() - loaded_models = [] + loaded_models: List[IModel] = [] for model_id in loaded_model_ids: model_variant = self._model_id_to_model_variant.get(model_id) diff --git a/sdk/python/src/model.py b/sdk/python/src/detail/model.py similarity index 71% rename from sdk/python/src/model.py rename to sdk/python/src/detail/model.py index f964a820..189920b1 100644 --- a/sdk/python/src/model.py +++ b/sdk/python/src/detail/model.py @@ -7,18 +7,19 @@ import logging from typing import Callable, List, Optional -from .imodel import IModel -from .openai.chat_client import ChatClient -from .openai.audio_client import AudioClient +from ..imodel import IModel +from ..openai.chat_client import ChatClient +from ..openai.audio_client import AudioClient from .model_variant import ModelVariant -from .exception import FoundryLocalException -from .detail.core_interop import CoreInterop +from ..exception import FoundryLocalException +from 
.core_interop import CoreInterop +from .model_data_types import ModelInfo logger = logging.getLogger(__name__) class Model(IModel): - """A model identified by an alias that groups one or more ``ModelVariant`` instances. + """A model identified by an alias that groups one or more variants. Operations are delegated to the currently selected variant. """ @@ -42,47 +43,26 @@ def _add_variant(self, variant: ModelVariant) -> None: if variant.info.cached and not self._selected_variant.info.cached: self._selected_variant = variant - def select_variant(self, variant: ModelVariant) -> None: + def select_variant(self, variant: IModel) -> None: """ - Select a specific model variant by its ModelVariant object. - The selected variant will be used for IModel operations. - - :param variant: ModelVariant to select + Select a specific model variant to use for IModel operations. + An IModel from ``variants`` can also be used directly. + + :param variant: IModel to select. Must be one of the variants in ``variants``. :raises FoundryLocalException: If variant is not valid for this model """ - if variant not in self._variants: + matching = next((v for v in self._variants if v.id == variant.id), None) + if matching is None: raise FoundryLocalException( - f"Model {self._alias} does not have a {variant.id} variant." + "Input variant was not found in Variants." ) - self._selected_variant = variant - - def get_latest_version(self, variant: ModelVariant) -> ModelVariant: - """ - Get the latest version of the specified model variant. - - :param variant: Model variant - :return: ModelVariant for latest version. Same as variant if that is the latest version - :raises FoundryLocalException: If variant is not valid for this model - """ - # Variants are sorted by version, so the first one matching the name is the latest version - for v in self._variants: - if v.info.name == variant.info.name: - return v - - raise FoundryLocalException( - f"Model {self._alias} does not have a {variant.id} variant." 
- ) + self._selected_variant = matching @property - def variants(self) -> List[ModelVariant]: + def variants(self) -> List[IModel]: """List of all variants for this model.""" - return self._variants.copy() # Return a copy to prevent external modification - - @property - def selected_variant(self) -> ModelVariant: - """Currently selected variant.""" - return self._selected_variant + return list(self._variants) # Return a copy to prevent external modification @property def id(self) -> str: @@ -94,6 +74,11 @@ def alias(self) -> str: """Alias of this model.""" return self._alias + @property + def info(self) -> ModelInfo: + """ModelInfo of the currently selected variant.""" + return self._selected_variant.info + @property def context_length(self) -> Optional[int]: """Maximum context length (in tokens) of the currently selected variant.""" diff --git a/sdk/python/src/detail/model_data_types.py b/sdk/python/src/detail/model_data_types.py index df367b44..46525dc7 100644 --- a/sdk/python/src/detail/model_data_types.py +++ b/sdk/python/src/detail/model_data_types.py @@ -57,24 +57,24 @@ class ModelInfo(BaseModel): name: str = Field(alias="name", description="Model variant name") version: int = Field(alias="version") alias: str = Field(..., description="Alias of the model") - display_name: Optional[str] = Field(alias="displayName") + display_name: Optional[str] = Field(default=None, alias="displayName") provider_type: str = Field(alias="providerType") uri: str = Field(alias="uri") model_type: str = Field(alias="modelType") prompt_template: Optional[PromptTemplate] = Field(default=None, alias="promptTemplate") - publisher: Optional[str] = Field(alias="publisher") + publisher: Optional[str] = Field(default=None, alias="publisher") model_settings: Optional[ModelSettings] = Field(default=None, alias="modelSettings") - license: Optional[str] = Field(alias="license") - license_description: Optional[str] = Field(alias="licenseDescription") + license: Optional[str] = 
Field(default=None, alias="license") + license_description: Optional[str] = Field(default=None, alias="licenseDescription") cached: bool = Field(alias="cached") - task: Optional[str] = Field(alias="task") - runtime: Optional[Runtime] = Field(alias="runtime") - file_size_mb: Optional[int] = Field(alias="fileSizeMb") - supports_tool_calling: Optional[bool] = Field(alias="supportsToolCalling") - max_output_tokens: Optional[int] = Field(alias="maxOutputTokens") - min_fl_version: Optional[str] = Field(alias="minFLVersion") + task: Optional[str] = Field(default=None, alias="task") + runtime: Optional[Runtime] = Field(default=None, alias="runtime") + file_size_mb: Optional[int] = Field(default=None, alias="fileSizeMb") + supports_tool_calling: Optional[bool] = Field(default=None, alias="supportsToolCalling") + max_output_tokens: Optional[int] = Field(default=None, alias="maxOutputTokens") + min_fl_version: Optional[str] = Field(default=None, alias="minFLVersion") created_at_unix: int = Field(alias="createdAt") - context_length: Optional[int] = Field(alias="contextLength") - input_modalities: Optional[str] = Field(alias="inputModalities") - output_modalities: Optional[str] = Field(alias="outputModalities") - capabilities: Optional[str] = Field(alias="capabilities") + context_length: Optional[int] = Field(default=None, alias="contextLength") + input_modalities: Optional[str] = Field(default=None, alias="inputModalities") + output_modalities: Optional[str] = Field(default=None, alias="outputModalities") + capabilities: Optional[str] = Field(default=None, alias="capabilities") diff --git a/sdk/python/src/model_variant.py b/sdk/python/src/detail/model_variant.py similarity index 84% rename from sdk/python/src/model_variant.py rename to sdk/python/src/detail/model_variant.py index 1c7ad717..a5ac02d4 100644 --- a/sdk/python/src/model_variant.py +++ b/sdk/python/src/detail/model_variant.py @@ -5,17 +5,17 @@ from __future__ import annotations import logging -from typing import 
Callable, Optional +from typing import Callable, List, Optional -from .imodel import IModel -from .exception import FoundryLocalException +from ..imodel import IModel +from ..exception import FoundryLocalException -from .detail.core_interop import CoreInterop, InteropRequest -from .detail.model_data_types import ModelInfo -from .detail.core_interop import get_cached_model_ids -from .detail.model_load_manager import ModelLoadManager -from .openai.audio_client import AudioClient -from .openai.chat_client import ChatClient +from .core_interop import CoreInterop, InteropRequest +from .model_data_types import ModelInfo +from .core_interop import get_cached_model_ids +from .model_load_manager import ModelLoadManager +from ..openai.audio_client import AudioClient +from ..openai.chat_client import ChatClient logger = logging.getLogger(__name__) @@ -62,6 +62,23 @@ def context_length(self) -> Optional[int]: """Maximum context length (in tokens) supported by this variant, or ``None`` if unknown.""" return self._model_info.context_length + @property + def variants(self) -> List[IModel]: + """A ModelVariant is a single variant, so variants returns itself.""" + return [self] + + def select_variant(self, variant: IModel) -> None: + """SelectVariant is not supported on a ModelVariant. + + Call ``Catalog.get_model()`` to get an IModel with all variants available. + + :raises FoundryLocalException: Always. + """ + raise FoundryLocalException( + f"select_variant is not supported on a ModelVariant. " + f'Call Catalog.get_model("{self._alias}") to get an IModel with all variants available.' + ) + @property def input_modalities(self) -> Optional[str]: """Comma-separated input modalities (e.g. 
``"text,image"``), or ``None`` if unknown.""" diff --git a/sdk/python/src/imodel.py b/sdk/python/src/imodel.py index 7f83d1cc..8237aeb4 100644 --- a/sdk/python/src/imodel.py +++ b/sdk/python/src/imodel.py @@ -5,10 +5,11 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Callable, Optional +from typing import Callable, List, Optional from .openai.chat_client import ChatClient from .openai.audio_client import AudioClient +from .detail.model_data_types import ModelInfo class IModel(ABC): """Abstract interface for a model that can be downloaded, loaded, and used for inference.""" @@ -25,6 +26,12 @@ def alias(self) -> str: """Model alias.""" pass + @property + @abstractmethod + def info(self) -> ModelInfo: + """Full model metadata.""" + pass + @property @abstractmethod def is_cached(self) -> bool: @@ -119,3 +126,20 @@ def get_audio_client(self) -> AudioClient: :return: AudioClient instance. """ pass + + @property + @abstractmethod + def variants(self) -> List['IModel']: + """Variants of the model that are available. Variants of the model are optimized for different devices.""" + pass + + @abstractmethod + def select_variant(self, variant: 'IModel') -> None: + """ + Select a model variant from ``variants`` to use for IModel operations. + An IModel from ``variants`` can also be used directly. + + :param variant: Model variant to select. Must be one of the variants in ``variants``. + :raises FoundryLocalException: If variant is not valid for this model. 
+ """ + pass diff --git a/sdk/python/test/test_catalog.py b/sdk/python/test/test_catalog.py index aeb39c20..2e5968cc 100644 --- a/sdk/python/test/test_catalog.py +++ b/sdk/python/test/test_catalog.py @@ -6,6 +6,11 @@ from __future__ import annotations +import json + +from foundry_local_sdk.catalog import Catalog +from foundry_local_sdk.detail.core_interop import Response + from .conftest import TEST_MODEL_ALIAS @@ -72,3 +77,91 @@ def test_should_return_none_for_unknown_variant_id(self, catalog): """get_model_variant() with a random ID should return None.""" result = catalog.get_model_variant("definitely-not-a-real-model-id-12345") assert result is None + + def test_should_resolve_latest_version_for_model_and_variant_inputs(self): + """get_latest_version() should resolve latest variant and preserve Model input when already latest.""" + + test_model_infos = [ + { + "id": "test-model:3", + "name": "test-model", + "version": 3, + "alias": "test-alias", + "displayName": "Test Model", + "providerType": "test", + "uri": "test://model/3", + "modelType": "ONNX", + "runtime": {"deviceType": "CPU", "executionProvider": "CPUExecutionProvider"}, + "cached": False, + "createdAt": 1700000003, + }, + { + "id": "test-model:2", + "name": "test-model", + "version": 2, + "alias": "test-alias", + "displayName": "Test Model", + "providerType": "test", + "uri": "test://model/2", + "modelType": "ONNX", + "runtime": {"deviceType": "CPU", "executionProvider": "CPUExecutionProvider"}, + "cached": False, + "createdAt": 1700000002, + }, + { + "id": "test-model:1", + "name": "test-model", + "version": 1, + "alias": "test-alias", + "displayName": "Test Model", + "providerType": "test", + "uri": "test://model/1", + "modelType": "ONNX", + "runtime": {"deviceType": "CPU", "executionProvider": "CPUExecutionProvider"}, + "cached": False, + "createdAt": 1700000001, + }, + ] + + class _MockCoreInterop: + def execute_command(self, command_name, command_input=None): + if command_name == 
"get_catalog_name": + return Response(data="TestCatalog", error=None) + if command_name == "get_model_list": + return Response(data=json.dumps(test_model_infos), error=None) + if command_name == "get_cached_models": + return Response(data="[]", error=None) + return Response(data=None, error=f"Unexpected command: {command_name}") + + class _MockModelLoadManager: + def list_loaded(self): + return [] + + catalog = Catalog(_MockModelLoadManager(), _MockCoreInterop()) + + model = catalog.get_model("test-alias") + assert model is not None + + variants = model.variants + assert len(variants) == 3 + + latest_variant = variants[0] + middle_variant = variants[1] + oldest_variant = variants[2] + + assert latest_variant.id == "test-model:3" + assert middle_variant.id == "test-model:2" + assert oldest_variant.id == "test-model:1" + + result1 = catalog.get_latest_version(latest_variant) + assert result1.id == "test-model:3" + + result2 = catalog.get_latest_version(middle_variant) + assert result2.id == "test-model:3" + + result3 = catalog.get_latest_version(oldest_variant) + assert result3.id == "test-model:3" + + model.select_variant(latest_variant) + result4 = catalog.get_latest_version(model) + assert result4 is model diff --git a/sdk/rust/README.md b/sdk/rust/README.md index aa848b03..6bcb9884 100644 --- a/sdk/rust/README.md +++ b/sdk/rust/README.md @@ -177,15 +177,15 @@ let loaded = catalog.get_loaded_models().await?; ### Model Lifecycle -Each `Model` wraps one or more `ModelVariant` entries (different quantizations, hardware targets). The SDK auto-selects the best available variant, preferring cached versions. +Each model may have multiple variants (different quantizations, hardware targets). The SDK auto-selects the best available variant, preferring cached versions. All models implement the `IModel` trait. 
```rust let model = catalog.get_model("phi-3.5-mini").await?; // Inspect available variants -println!("Selected: {}", model.selected_variant().id()); +println!("Selected: {}", model.id()); for v in model.variants() { - println!(" {} (cached: {})", v.id(), v.info().cached); + println!(" {} (info.cached: {})", v.id(), v.info().cached); } ``` @@ -193,10 +193,10 @@ Download, load, and unload: ```rust // Download with progress reporting -model.download(Some(|progress: &str| { +model.download(Some(Box::new(|progress: &str| { print!("\r{progress}"); std::io::Write::flush(&mut std::io::stdout()).ok(); -})).await?; +}))).await?; // Load into memory model.load().await?; diff --git a/sdk/rust/examples/tool_calling.rs b/sdk/rust/examples/tool_calling.rs index 192b9ff0..fecf6bc5 100644 --- a/sdk/rust/examples/tool_calling.rs +++ b/sdk/rust/examples/tool_calling.rs @@ -61,7 +61,7 @@ async fn main() -> Result<()> { let models = manager.catalog().get_models().await?; let model = models .iter() - .find(|m| m.selected_variant().info().supports_tool_calling == Some(true)) + .find(|m| m.info().supports_tool_calling == Some(true)) .or_else(|| models.first()) .expect("No models available"); diff --git a/sdk/rust/src/catalog.rs b/sdk/rust/src/catalog.rs index d9d5bb51..26a737e9 100644 --- a/sdk/rust/src/catalog.rs +++ b/sdk/rust/src/catalog.rs @@ -6,10 +6,10 @@ use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use crate::detail::core_interop::CoreInterop; +use crate::detail::model::Model; +use crate::detail::model_variant::ModelVariant; use crate::detail::ModelLoadManager; use crate::error::{FoundryLocalError, Result}; -use crate::model::Model; -use crate::model_variant::ModelVariant; use crate::types::ModelInfo; /// How long the catalog cache remains valid before a refresh. @@ -39,7 +39,7 @@ impl CacheInvalidator { /// All mutable catalog data behind a single lock to prevent split-brain reads. 
struct CatalogState { models_by_alias: HashMap<String, Arc<Model>>, - variants_by_id: HashMap<String, Arc<ModelVariant>>, + variants_by_id: HashMap<String, Arc<Model>>, last_refresh: Option<Instant>, } @@ -148,7 +148,11 @@ impl Catalog { } /// Look up a specific model variant by its unique id. - pub async fn get_model_variant(&self, id: &str) -> Result<Arc<ModelVariant>> { + /// + /// NOTE: This will return a `Model` representing a single variant. Use + /// [`get_model`](Catalog::get_model) to obtain a `Model` with all + /// available variants. + pub async fn get_model_variant(&self, id: &str) -> Result<Arc<Model>> { if id.trim().is_empty() { return Err(FoundryLocalError::Validation { reason: "Variant id must be a non-empty string".into(), @@ -165,7 +169,7 @@ impl Catalog { } /// Return only the model variants that are currently cached on disk. - pub async fn get_cached_models(&self) -> Result<Vec<Arc<ModelVariant>>> { + pub async fn get_cached_models(&self) -> Result<Vec<Arc<Model>>> { self.update_models().await?; let raw = self .core @@ -183,7 +187,7 @@ impl Catalog { } /// Return model variants that are currently loaded into memory. - pub async fn get_loaded_models(&self) -> Result<Vec<Arc<ModelVariant>>> { + pub async fn get_loaded_models(&self) -> Result<Vec<Arc<Model>>> { self.update_models().await?; let loaded_ids = self.model_load_manager.list_loaded().await?; let s = self.lock_state()?; @@ -193,6 +197,36 @@ impl Catalog { .collect()) } + /// Resolve the latest catalog version for the provided model or variant.
+ pub async fn get_latest_version(&self, model_or_model_variant: &Model) -> Result<Arc<Model>> { + self.update_models().await?; + let s = self.lock_state()?; + + let model = s + .models_by_alias + .get(model_or_model_variant.alias()) + .ok_or_else(|| FoundryLocalError::ModelOperation { + reason: format!( + "Model with alias '{}' not found in catalog.", + model_or_model_variant.alias() + ), + })?; + + let latest = model + .variants() + .into_iter() + .find(|variant| variant.info().name == model_or_model_variant.info().name) + .ok_or_else(|| FoundryLocalError::Internal { + reason: format!( + "Mismatch between model (alias:{}) and model variant (alias:{}).", + model.alias(), + model_or_model_variant.alias() + ), + })?; + + Ok(latest) + } + async fn force_refresh(&self) -> Result<()> { let raw = self .core @@ -216,22 +250,22 @@ impl Catalog { }; let mut alias_map_build: HashMap<String, Model> = HashMap::new(); - let mut id_map: HashMap<String, Arc<ModelVariant>> = HashMap::new(); + let mut id_map: HashMap<String, Arc<Model>> = HashMap::new(); for info in infos { let id = info.id.clone(); let alias = info.alias.clone(); - let variant = Arc::new(ModelVariant::new( + let variant = ModelVariant::new( info, Arc::clone(&self.core), Arc::clone(&self.model_load_manager), self.invalidator.clone(), - )); - id_map.insert(id, Arc::clone(&variant)); + ); + id_map.insert(id, Arc::new(Model::from_variant(variant.clone()))); alias_map_build .entry(alias) - .or_insert_with_key(|a| Model::new(a.clone(), Arc::clone(&self.core))) + .or_insert_with_key(|a| Model::from_group(a.clone(), Arc::clone(&self.core))) .add_variant(variant); } diff --git a/sdk/rust/src/detail/mod.rs b/sdk/rust/src/detail/mod.rs index c7f2fd32..b153ed5b 100644 --- a/sdk/rust/src/detail/mod.rs +++ b/sdk/rust/src/detail/mod.rs @@ -1,4 +1,6 @@ pub(crate) mod core_interop; +pub(crate) mod model; mod model_load_manager; +pub(crate) mod model_variant; pub use self::model_load_manager::ModelLoadManager; diff --git a/sdk/rust/src/detail/model.rs b/sdk/rust/src/detail/model.rs new file mode
100644 index 00000000..196ebe35 --- /dev/null +++ b/sdk/rust/src/detail/model.rs @@ -0,0 +1,300 @@ +//! Public model type backed by an internal enum. +//! +//! Users interact solely with [`Model`]. The internal representation +//! distinguishes between a single variant and a group of variants sharing +//! the same alias, but callers never need to know which kind they hold. + +use std::fmt; +use std::path::PathBuf; +use std::sync::atomic::{AtomicUsize, Ordering::Relaxed}; +use std::sync::Arc; + +use super::core_interop::CoreInterop; +use super::model_variant::ModelVariant; +use crate::error::{FoundryLocalError, Result}; +use crate::openai::AudioClient; +use crate::openai::ChatClient; +use crate::types::ModelInfo; + +/// The public model type. +/// +/// A `Model` may represent either a group of variants (as returned by +/// [`Catalog::get_model`](crate::Catalog::get_model)) or a single variant (as +/// returned by [`Catalog::get_model_variant`](crate::Catalog::get_model_variant) +/// or [`Model::variants`]). +/// +/// When a `Model` groups multiple variants, operations are forwarded to +/// the currently selected variant. Use [`variants`](Model::variants) to +/// inspect the available variants and [`select_variant`](Model::select_variant) +/// to change the selection. +pub struct Model { + inner: ModelKind, +} + +#[allow(clippy::large_enum_variant)] +enum ModelKind { + /// A single model variant (from `get_model_variant` or `variants()`). + ModelVariant(ModelVariant), + /// A group of variants sharing the same alias (from `get_model`). 
+ Model { + alias: String, + core: Arc<CoreInterop>, + variants: Vec<ModelVariant>, + selected: AtomicUsize, + }, +} + +impl Clone for Model { + fn clone(&self) -> Self { + Self { + inner: match &self.inner { + ModelKind::ModelVariant(v) => ModelKind::ModelVariant(v.clone()), + ModelKind::Model { + alias, + core, + variants, + selected, + } => ModelKind::Model { + alias: alias.clone(), + core: Arc::clone(core), + variants: variants.clone(), + selected: AtomicUsize::new(selected.load(Relaxed)), + }, + }, + } + } +} + +impl fmt::Debug for Model { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.inner { + ModelKind::ModelVariant(v) => f + .debug_struct("Model::ModelVariant") + .field("id", &v.id()) + .field("alias", &v.alias()) + .finish(), + ModelKind::Model { + alias, + variants, + selected, + .. + } => f + .debug_struct("Model::Model") + .field("alias", alias) + .field("id", &variants[selected.load(Relaxed)].id()) + .field("variants_count", &variants.len()) + .field("selected_index", &selected.load(Relaxed)) + .finish(), + } + } +} + +// ── Construction (crate-internal) ──────────────────────────────────────────── + +impl Model { + /// Create a `Model` wrapping a single variant. + pub(crate) fn from_variant(variant: ModelVariant) -> Self { + Self { + inner: ModelKind::ModelVariant(variant), + } + } + + /// Create a `Model` grouping multiple variants under one alias. + pub(crate) fn from_group(alias: String, core: Arc<CoreInterop>) -> Self { + Self { + inner: ModelKind::Model { + alias, + core, + variants: Vec::new(), + selected: AtomicUsize::new(0), + }, + } + } + + /// Add a variant to a group. Panics if called on a `ModelVariant` kind. + /// + /// If the new variant is cached and the current selection is not, the new + /// variant becomes the selected one. + pub(crate) fn add_variant(&mut self, variant: ModelVariant) { + match &mut self.inner { + ModelKind::Model { + variants, selected, ..
+ } => { + variants.push(variant); + let new_idx = variants.len() - 1; + let current = selected.load(Relaxed); + if variants[new_idx].info_ref().cached && !variants[current].info_ref().cached { + selected.store(new_idx, Relaxed); + } + } + ModelKind::ModelVariant(_) => { + panic!("add_variant called on a single-variant Model"); + } + } + } +} + +// ── Private helpers ────────────────────────────────────────────────────────── + +impl Model { + fn selected_variant(&self) -> &ModelVariant { + match &self.inner { + ModelKind::ModelVariant(v) => v, + ModelKind::Model { + variants, selected, .. + } => &variants[selected.load(Relaxed)], + } + } +} + +// ── Public API ─────────────────────────────────────────────────────────────── + +impl Model { + /// Unique identifier of the (selected) variant. + pub fn id(&self) -> &str { + self.selected_variant().id() + } + + /// Alias shared by all variants of this model. + pub fn alias(&self) -> &str { + match &self.inner { + ModelKind::ModelVariant(v) => v.alias(), + ModelKind::Model { alias, .. } => alias, + } + } + + /// Full catalog metadata for the (selected) variant. + pub fn info(&self) -> &ModelInfo { + self.selected_variant().info() + } + + /// Maximum context length (in tokens), or `None` if unknown. + pub fn context_length(&self) -> Option { + self.selected_variant().info().context_length + } + + /// Comma-separated input modalities (e.g. `"text,image"`), or `None`. + pub fn input_modalities(&self) -> Option<&str> { + self.selected_variant().info().input_modalities.as_deref() + } + + /// Comma-separated output modalities (e.g. `"text"`), or `None`. + pub fn output_modalities(&self) -> Option<&str> { + self.selected_variant().info().output_modalities.as_deref() + } + + /// Capability tags (e.g. `"reasoning"`), or `None`. + pub fn capabilities(&self) -> Option<&str> { + self.selected_variant().info().capabilities.as_deref() + } + + /// Whether the model supports tool/function calling, or `None`. 
+ pub fn supports_tool_calling(&self) -> Option { + self.selected_variant().info().supports_tool_calling + } + + /// Whether the (selected) variant is cached on disk. + pub async fn is_cached(&self) -> Result<bool> { + self.selected_variant().is_cached().await + } + + /// Whether the (selected) variant is loaded into memory. + pub async fn is_loaded(&self) -> Result<bool> { + self.selected_variant().is_loaded().await + } + + /// Download the (selected) variant. If `progress` is provided it + /// receives human-readable progress strings as they arrive. + pub async fn download<F>(&self, progress: Option<F>) -> Result<()> + where + F: FnMut(&str) + Send + 'static, + { + self.selected_variant().download(progress).await + } + + /// Return the local file-system path of the (selected) variant. + pub async fn path(&self) -> Result<PathBuf> { + self.selected_variant().path().await + } + + /// Load the (selected) variant into memory. + pub async fn load(&self) -> Result<()> { + self.selected_variant().load().await + } + + /// Unload the (selected) variant from memory. + pub async fn unload(&self) -> Result { + self.selected_variant().unload().await + } + + /// Remove the (selected) variant from the local cache. + pub async fn remove_from_cache(&self) -> Result { + self.selected_variant().remove_from_cache().await + } + + /// Create a [`ChatClient`] bound to the (selected) variant. + pub fn create_chat_client(&self) -> ChatClient { + self.selected_variant().create_chat_client() + } + + /// Create an [`AudioClient`] bound to the (selected) variant. + pub fn create_audio_client(&self) -> AudioClient { + self.selected_variant().create_audio_client() + } + + /// Available variants of this model. + /// + /// For a single-variant model (e.g. from + /// [`Catalog::get_model_variant`](crate::Catalog::get_model_variant)), + /// this returns a single-element list containing itself.
+ pub fn variants(&self) -> Vec<Arc<Model>> { + match &self.inner { + ModelKind::ModelVariant(v) => { + vec![Arc::new(Model::from_variant(v.clone()))] + } + ModelKind::Model { variants, .. } => variants + .iter() + .map(|v| Arc::new(Model::from_variant(v.clone()))) + .collect(), + } + } + + /// Select a variant by its unique id. + /// + /// # Errors + /// + /// Returns an error if no variant with the given id exists. + /// For single-variant models this always returns an error — use + /// [`Catalog::get_model`](crate::Catalog::get_model) to obtain a model + /// with all variants available. + pub fn select_variant(&self, id: &str) -> Result<()> { + match &self.inner { + ModelKind::ModelVariant(v) => Err(FoundryLocalError::ModelOperation { + reason: format!( + "select_variant is not supported on a single variant. \ + Call Catalog::get_model(\"{}\") to get a model with all variants available.", + v.alias() + ), + }), + ModelKind::Model { + variants, + selected, + alias, + .. + } => match variants.iter().position(|v| v.id() == id) { + Some(pos) => { + selected.store(pos, Relaxed); + Ok(()) + } + None => { + let available: Vec<&str> = variants.iter().map(|v| v.id()).collect(); + Err(FoundryLocalError::ModelOperation { + reason: format!( + "Variant '{id}' not found for model '{alias}'. Available: {available:?}", + ), + }) + } + }, + } + } +} diff --git a/sdk/rust/src/model_variant.rs b/sdk/rust/src/detail/model_variant.rs similarity index 63% rename from sdk/rust/src/model_variant.rs rename to sdk/rust/src/detail/model_variant.rs index 760306f6..636c5d5b 100644 --- a/sdk/rust/src/model_variant.rs +++ b/sdk/rust/src/detail/model_variant.rs @@ -1,4 +1,7 @@ //! A single model variant backed by [`ModelInfo`]. +//! +//! This type is an implementation detail. Public APIs return +//! [`Arc<Model>`](crate::Model) instead.
use std::fmt; use std::path::PathBuf; @@ -6,9 +9,9 @@ use std::sync::Arc; use serde_json::json; +use super::core_interop::CoreInterop; +use super::ModelLoadManager; use crate::catalog::CacheInvalidator; -use crate::detail::core_interop::CoreInterop; -use crate::detail::ModelLoadManager; use crate::error::Result; use crate::openai::AudioClient; use crate::openai::ChatClient; @@ -16,8 +19,10 @@ use crate::types::ModelInfo; /// Represents one specific variant of a model (a particular id within an alias /// group). +/// +/// This is an implementation detail — callers should use [`Model`](crate::Model). #[derive(Clone)] -pub struct ModelVariant { +pub(crate) struct ModelVariant { info: ModelInfo, core: Arc<CoreInterop>, model_load_manager: Arc<ModelLoadManager>, @@ -27,8 +32,8 @@ pub struct ModelVariant { impl fmt::Debug for ModelVariant { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ModelVariant") - .field("id", &self.id()) - .field("alias", &self.alias()) + .field("id", &self.info.id) + .field("alias", &self.info.alias) .finish() } } @@ -48,28 +53,23 @@ impl ModelVariant { } } - /// The full [`ModelInfo`] metadata for this variant. - pub fn info(&self) -> &ModelInfo { - &self.info - } - - /// Unique identifier. - pub fn id(&self) -> &str { + pub(crate) fn id(&self) -> &str { &self.info.id } - /// Alias shared with sibling variants. - pub fn alias(&self) -> &str { + pub(crate) fn alias(&self) -> &str { &self.info.alias } - /// Check whether the variant is cached locally by querying the native - /// core. - /// - /// Each call performs a full IPC round-trip. When checking many variants, - /// prefer [`Catalog::get_cached_models`] which fetches the full list in a - /// single call.
- pub async fn is_cached(&self) -> Result<bool> { + pub(crate) fn info(&self) -> &ModelInfo { + &self.info + } + + pub(crate) fn info_ref(&self) -> &ModelInfo { + &self.info + } + + pub(crate) async fn is_cached(&self) -> Result<bool> { let raw = self .core .execute_command_async("get_cached_models".into(), None) @@ -81,15 +81,12 @@ impl ModelVariant { Ok(cached_ids.iter().any(|id| id == &self.info.id)) } - /// Check whether the variant is currently loaded into memory. - pub async fn is_loaded(&self) -> Result<bool> { + pub(crate) async fn is_loaded(&self) -> Result<bool> { let loaded = self.model_load_manager.list_loaded().await?; Ok(loaded.iter().any(|id| id == &self.info.id)) } - /// Download the model variant. If `progress` is provided, it receives - /// human-readable progress strings as the download proceeds. - pub async fn download<F>(&self, progress: Option<F>) -> Result<()> + pub(crate) async fn download<F>(&self, progress: Option<F>) -> Result<()> where F: FnMut(&str) + Send + 'static, { @@ -110,8 +107,7 @@ impl ModelVariant { Ok(()) } - /// Return the local file-system path where this variant is stored. - pub async fn path(&self) -> Result<PathBuf> { + pub(crate) async fn path(&self) -> Result<PathBuf> { let params = json!({ "Params": { "Model": self.info.id } }); let path_str = self .core @@ -120,18 +116,15 @@ impl ModelVariant { Ok(PathBuf::from(path_str)) } - /// Load the variant into memory. - pub async fn load(&self) -> Result<()> { + pub(crate) async fn load(&self) -> Result<()> { self.model_load_manager.load(&self.info.id).await } - /// Unload the variant from memory. - pub async fn unload(&self) -> Result { + pub(crate) async fn unload(&self) -> Result { self.model_load_manager.unload(&self.info.id).await } - /// Remove the variant from the local cache.
- pub async fn remove_from_cache(&self) -> Result { + pub(crate) async fn remove_from_cache(&self) -> Result { let params = json!({ "Params": { "Model": self.info.id } }); let result = self .core @@ -141,13 +134,11 @@ impl ModelVariant { Ok(result) } - /// Create a [`ChatClient`] bound to this variant. - pub fn create_chat_client(&self) -> ChatClient { + pub(crate) fn create_chat_client(&self) -> ChatClient { ChatClient::new(&self.info.id, Arc::clone(&self.core)) } - /// Create an [`AudioClient`] bound to this variant. - pub fn create_audio_client(&self) -> AudioClient { + pub(crate) fn create_audio_client(&self) -> AudioClient { AudioClient::new(&self.info.id, Arc::clone(&self.core)) } } diff --git a/sdk/rust/src/lib.rs b/sdk/rust/src/lib.rs index c12feef1..872a875c 100644 --- a/sdk/rust/src/lib.rs +++ b/sdk/rust/src/lib.rs @@ -6,8 +6,6 @@ mod catalog; mod configuration; mod error; mod foundry_local_manager; -mod model; -mod model_variant; mod types; pub(crate) mod detail; @@ -15,10 +13,9 @@ pub mod openai; pub use self::catalog::Catalog; pub use self::configuration::{FoundryLocalConfig, LogLevel, Logger}; +pub use self::detail::model::Model; pub use self::error::FoundryLocalError; pub use self::foundry_local_manager::FoundryLocalManager; -pub use self::model::Model; -pub use self::model_variant::ModelVariant; pub use self::types::{ ChatResponseFormat, ChatToolChoice, DeviceType, EpDownloadResult, EpInfo, ModelInfo, ModelSettings, Parameter, PromptTemplate, Runtime, diff --git a/sdk/rust/src/model.rs b/sdk/rust/src/model.rs deleted file mode 100644 index 9d08f9a5..00000000 --- a/sdk/rust/src/model.rs +++ /dev/null @@ -1,183 +0,0 @@ -//! High-level model abstraction that wraps one or more [`ModelVariant`]s -//! sharing the same alias. 
- -use std::fmt; -use std::path::PathBuf; -use std::sync::atomic::{AtomicUsize, Ordering::Relaxed}; -use std::sync::Arc; - -use crate::detail::core_interop::CoreInterop; -use crate::error::{FoundryLocalError, Result}; -use crate::model_variant::ModelVariant; -use crate::openai::AudioClient; -use crate::openai::ChatClient; - -/// A model groups one or more [`ModelVariant`]s that share the same alias. -/// -/// By default the variant that is already cached locally is selected. You -/// can override the selection with [`Model::select_variant`]. -pub struct Model { - alias: String, - core: Arc<CoreInterop>, - variants: Vec<Arc<ModelVariant>>, - selected_index: AtomicUsize, -} - -impl Clone for Model { - fn clone(&self) -> Self { - Self { - alias: self.alias.clone(), - core: Arc::clone(&self.core), - variants: self.variants.clone(), - selected_index: AtomicUsize::new(self.selected_index.load(Relaxed)), - } - } -} - -impl fmt::Debug for Model { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Model") - .field("alias", &self.alias()) - .field("id", &self.id()) - .field("variants_count", &self.variants.len()) - .field("selected_index", &self.selected_index.load(Relaxed)) - .finish() - } -} - -impl Model { - pub(crate) fn new(alias: String, core: Arc<CoreInterop>) -> Self { - Self { - alias, - core, - variants: Vec::new(), - selected_index: AtomicUsize::new(0), - } - } - - /// Add a variant. If the new variant is cached and the current selection - /// is not, the new variant becomes the selected one. - pub(crate) fn add_variant(&mut self, variant: Arc<ModelVariant>) { - self.variants.push(variant); - let new_idx = self.variants.len() - 1; - let current = self.selected_index.load(Relaxed); - - // Prefer a cached variant over a non-cached one. - if self.variants[new_idx].info().cached && !self.variants[current].info().cached { - self.selected_index.store(new_idx, Relaxed); - } - } - - /// Select a variant by its unique id.
- pub fn select_variant(&self, id: &str) -> Result<()> { - match self.variants.iter().position(|v| v.id() == id) { - Some(pos) => { - self.selected_index.store(pos, Relaxed); - Ok(()) - } - None => { - let available: Vec<&str> = self.variants.iter().map(|v| v.id()).collect(); - Err(FoundryLocalError::ModelOperation { - reason: format!( - "Variant '{id}' not found for model '{}'. Available: {available:?}", - self.alias - ), - }) - } - } - } - - /// Returns a reference to the currently selected variant. - pub fn selected_variant(&self) -> &ModelVariant { - &self.variants[self.selected_index.load(Relaxed)] - } - - /// Returns all variants that belong to this model. - pub fn variants(&self) -> &[Arc<ModelVariant>] { - &self.variants - } - - /// Alias shared by all variants in this model. - pub fn alias(&self) -> &str { - &self.alias - } - - /// Unique identifier of the selected variant. - pub fn id(&self) -> &str { - self.selected_variant().id() - } - - /// Whether the selected variant is cached on disk. - pub async fn is_cached(&self) -> Result<bool> { - self.selected_variant().is_cached().await - } - - /// Whether the selected variant is loaded into memory. - pub async fn is_loaded(&self) -> Result<bool> { - self.selected_variant().is_loaded().await - } - - /// Context length (maximum input tokens) of the selected variant. - pub fn context_length(&self) -> Option { - self.selected_variant().info().context_length - } - - /// Input modalities of the selected variant (e.g. "text", "text,image"). - pub fn input_modalities(&self) -> Option<&str> { - self.selected_variant().info().input_modalities.as_deref() - } - - /// Output modalities of the selected variant (e.g. "text"). - pub fn output_modalities(&self) -> Option<&str> { - self.selected_variant().info().output_modalities.as_deref() - } - - /// Capabilities of the selected variant (e.g. "reasoning", "tool-calling").
- pub fn capabilities(&self) -> Option<&str> { - self.selected_variant().info().capabilities.as_deref() - } - - /// Whether the selected variant supports tool calling. - pub fn supports_tool_calling(&self) -> Option { - self.selected_variant().info().supports_tool_calling - } - - /// Download the selected variant. If `progress` is provided, it receives - /// human-readable progress strings as they arrive from the native core. - pub async fn download<F>(&self, progress: Option<F>) -> Result<()> - where - F: FnMut(&str) + Send + 'static, - { - self.selected_variant().download(progress).await - } - - /// Return the local file-system path of the selected variant. - pub async fn path(&self) -> Result<PathBuf> { - self.selected_variant().path().await - } - - /// Load the selected variant into memory. - pub async fn load(&self) -> Result<()> { - self.selected_variant().load().await - } - - /// Unload the selected variant from memory. - pub async fn unload(&self) -> Result { - self.selected_variant().unload().await - } - - /// Remove the selected variant from the local cache. - pub async fn remove_from_cache(&self) -> Result { - self.selected_variant().remove_from_cache().await - } - - /// Create a [`ChatClient`] bound to the selected variant. - pub fn create_chat_client(&self) -> ChatClient { - ChatClient::new(self.id(), Arc::clone(&self.core)) - } - - /// Create an [`AudioClient`] bound to the selected variant.
- pub fn create_audio_client(&self) -> AudioClient { - AudioClient::new(self.id(), Arc::clone(&self.core)) - } -} diff --git a/sdk/rust/tests/integration/model_test.rs b/sdk/rust/tests/integration/model_test.rs index d2b68b77..4e3b371b 100644 --- a/sdk/rust/tests/integration/model_test.rs +++ b/sdk/rust/tests/integration/model_test.rs @@ -111,11 +111,12 @@ async fn should_have_selected_variant_matching_id() { .await .expect("get_model failed"); - let selected = model.selected_variant(); + // The model's id() should return the selected variant's id + // info() delegates to the selected variant, so id() and info().id must agree assert_eq!( - selected.id(), model.id(), - "selected_variant().id() should match model.id()" + model.info().id, + "model.id() should match model.info().id (the selected variant's metadata)" ); } @@ -177,7 +178,7 @@ async fn should_select_variant_by_id() { ); // Restore the original variant so other tests sharing this - // Arc via the catalog are not affected. + // model via the catalog are not affected. model .select_variant(&original_id) .expect("restoring original variant should succeed");