diff --git a/web/src/components/classification/ClassificationModelWizardDialog.tsx b/web/src/components/classification/ClassificationModelWizardDialog.tsx index 621c9ea90..eaba57d08 100644 --- a/web/src/components/classification/ClassificationModelWizardDialog.tsx +++ b/web/src/components/classification/ClassificationModelWizardDialog.tsx @@ -7,7 +7,8 @@ import { DialogHeader, DialogTitle, } from "../ui/dialog"; -import { useState } from "react"; +import { useReducer } from "react"; +import Step1NameAndDefine, { Step1FormData } from "./wizard/Step1NameAndDefine"; const STEPS = [ "classificationWizard.steps.nameAndDefine", @@ -20,21 +21,74 @@ type ClassificationModelWizardDialogProps = { open: boolean; onClose: () => void; }; + +type WizardState = { + currentStep: number; + step1Data?: Step1FormData; + // Future steps can be added here + // step2Data?: Step2FormData; + // step3Data?: Step3FormData; +}; + +type WizardAction = + | { type: "NEXT_STEP"; payload?: Partial } + | { type: "PREVIOUS_STEP" } + | { type: "SET_STEP_1"; payload: Step1FormData } + | { type: "RESET" }; + +const initialState: WizardState = { + currentStep: 0, +}; + +function wizardReducer(state: WizardState, action: WizardAction): WizardState { + switch (action.type) { + case "SET_STEP_1": + return { + ...state, + step1Data: action.payload, + currentStep: 1, + }; + case "NEXT_STEP": + return { + ...state, + ...action.payload, + currentStep: state.currentStep + 1, + }; + case "PREVIOUS_STEP": + return { + ...state, + currentStep: Math.max(0, state.currentStep - 1), + }; + case "RESET": + return initialState; + default: + return state; + } +} + export default function ClassificationModelWizardDialog({ open, onClose, }: ClassificationModelWizardDialogProps) { const { t } = useTranslation(["views/classificationModel"]); - // step management - const [currentStep, _] = useState(0); + const [wizardState, dispatch] = useReducer(wizardReducer, initialState); + + const handleStep1Next = (data: Step1FormData) => { + dispatch({ type: "SET_STEP_1", payload: data }); + }; + + const handleCancel = () => { + dispatch({ type: "RESET" }); + onClose(); + }; return ( { if (!open) { - onClose; + handleCancel(); } }} > @@ -46,19 +100,25 @@ export default function ClassificationModelWizardDialog({ > {t("wizard.title")} - {currentStep === 0 && ( + {wizardState.currentStep === 0 && ( {t("wizard.description")} )}
-
+ {wizardState.currentStep === 0 && ( + + )}
diff --git a/web/src/components/classification/wizard/Step1NameAndDefine.tsx b/web/src/components/classification/wizard/Step1NameAndDefine.tsx new file mode 100644 index 000000000..37ffa49e9 --- /dev/null +++ b/web/src/components/classification/wizard/Step1NameAndDefine.tsx @@ -0,0 +1,323 @@ +import { Button } from "@/components/ui/button"; +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, +} from "@/components/ui/form"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { RadioGroup, RadioGroupItem } from "@/components/ui/radio-group"; +import { useForm } from "react-hook-form"; +import { zodResolver } from "@hookform/resolvers/zod"; +import { z } from "zod"; +import { LuX } from "react-icons/lu"; +import { MdAddBox } from "react-icons/md"; + +export type ModelType = "state" | "object"; +export type ObjectClassificationType = "sub_label" | "attribute"; + +export type Step1FormData = { + modelName: string; + modelType: ModelType; + objectType?: ObjectClassificationType; + classes: string[]; +}; + +type Step1NameAndDefineProps = { + initialData?: Partial; + onNext: (data: Step1FormData) => void; + onCancel: () => void; +}; + +export default function Step1NameAndDefine({ + initialData, + onNext, + onCancel, +}: Step1NameAndDefineProps) { + const step1FormData = z + .object({ + modelName: z + .string() + .min(1, "Model name is required") + .max(64, "Model name must be 64 characters or less") + .refine((value) => !/^\d+$/.test(value), { + message: "Model name cannot contain only numbers", + }), + modelType: z.enum(["state", "object"]), + objectType: z.enum(["sub_label", "attribute"]).optional(), + classes: z + .array(z.string()) + .min(1, "At least one class field is required") + .refine( + (classes) => { + const nonEmpty = classes.filter((c) => c.trim().length > 0); + return nonEmpty.length >= 1; + }, + { message: "At least 1 class is required" }, + ) + .refine( + (classes) => { + const nonEmpty = classes.filter((c) => c.trim().length > 0); + const unique = new Set(nonEmpty.map((c) => c.toLowerCase())); + return unique.size === nonEmpty.length; + }, + { message: "Class names must be unique" }, + ), + }) + .refine( + (data) => { + // State models require at least 2 classes + if (data.modelType === "state") { + const nonEmpty = data.classes.filter((c) => c.trim().length > 0); + return nonEmpty.length >= 2; + } + return true; + }, + { + message: "State models require at least 2 classes", + path: ["classes"], + }, + ) + .refine( + (data) => { + // Object models require objectType to be selected + if (data.modelType === "object") { + return data.objectType !== undefined; + } + return true; + }, + { + message: "Please select a classification type", + path: ["objectType"], + }, + ); + + const form = useForm>({ + resolver: zodResolver(step1FormData), + defaultValues: { + modelName: initialData?.modelName || "", + modelType: initialData?.modelType || "state", + objectType: initialData?.objectType || "sub_label", + classes: initialData?.classes?.length ? initialData.classes : [""], + }, + mode: "onChange", + }); + + const watchedClasses = form.watch("classes"); + const watchedModelType = form.watch("modelType"); + const watchedObjectType = form.watch("objectType"); + + const handleAddClass = () => { + const currentClasses = form.getValues("classes"); + form.setValue("classes", [...currentClasses, ""], { shouldValidate: true }); + }; + + const handleRemoveClass = (index: number) => { + const currentClasses = form.getValues("classes"); + const newClasses = currentClasses.filter((_, i) => i !== index); + + // Ensure at least one field remains (even if empty) + if (newClasses.length === 0) { + form.setValue("classes", [""], { shouldValidate: true }); + } else { + form.setValue("classes", newClasses, { shouldValidate: true }); + } + }; + + const onSubmit = (data: z.infer) => { + // Filter out empty classes + const filteredClasses = data.classes.filter((c) => c.trim().length > 0); + onNext({ + ...data, + classes: filteredClasses, + }); + }; + + return ( +
+
+ + ( + + Name + + + + + + )} + /> + + ( + + Type + + +
+ + +
+
+ + +
+
+
+ +
+ )} + /> + + {watchedModelType === "object" && ( + ( + + Classification Type + + +
+ + +
+
+ + +
+
+
+ +
+ )} + /> + )} + +
+
+ Classes + +
+
+ {watchedClasses.map((_, index) => ( + ( + + +
+ + {watchedClasses.length > 1 && ( + + )} +
+
+
+ )} + /> + ))} +
+ {form.formState.errors.classes && ( +

+ {form.formState.errors.classes.message} +

+ )} +
+ + + +
+ + +
+
+ ); +} diff --git a/web/src/views/classification/ModelSelectionView.tsx b/web/src/views/classification/ModelSelectionView.tsx index 6d6287b4d..6641416b3 100644 --- a/web/src/views/classification/ModelSelectionView.tsx +++ b/web/src/views/classification/ModelSelectionView.tsx @@ -139,13 +139,17 @@ function ModelCard({ config, onClick }: ModelCardProps) { }>(`classification/${config.name}/dataset`, { revalidateOnFocus: false }); const coverImage = useMemo(() => { - if (!dataset?.length) { + if (!dataset) { return undefined; } const keys = Object.keys(dataset).filter((key) => key != "none"); const selectedKey = keys[0]; + if (!dataset[selectedKey]) { + return undefined; + } + return { name: selectedKey, img: dataset[selectedKey][0],