124 lines
3.4 KiB
TypeScript
124 lines
3.4 KiB
TypeScript
"use client";
|
|
|
|
import { createContext, useCallback, useContext, useEffect, useMemo, useState, ReactNode } from "react";

import api from "@/shared/api/axios";
import { ApiResponse, unwrap } from "@/shared/api/types";
|
|
|
|
/**
 * Shape of a task record returned by `GET /api/videos/tasks/{task_id}`
 * (after `unwrap`).
 */
interface Task {
  /** Identifier used to build the polling URL. */
  task_id: string;

  /**
   * Task state. "completed", "failed" and "not_found" end polling;
   * other values are treated as still running.
   */
  status: string;

  /** Progress value reported by the backend (units/range unconfirmed — presumably a percentage). */
  progress: number;

  /** Human-readable status text from the backend. */
  message: string;

  /** Optional URL supplied by the backend; presumably the generated video's download link — confirm. */
  download_url?: string;
}
|
|
|
|
/** Value exposed to consumers of the task context via `useTask()`. */
interface TaskContextType {
  /** Latest task record fetched by polling, or null before the first response / after clearing. */
  currentTask: Task | null;

  /** True while a task is being tracked; reset when it reaches a terminal state, errors, or is cleared. */
  isGenerating: boolean;

  /** Begin tracking (polling) the task with the given ID. */
  startTask: (taskId: string) => void;

  /** Stop tracking the current task and reset all task state. */
  clearTask: () => void;
}
|
|
|
|
/**
 * Context holding the task state published by `TaskProvider`.
 * The default value is `undefined`, which `useTask` uses to detect a
 * missing provider.
 */
const TaskContext = createContext<TaskContextType | undefined>(
  undefined,
);
|
|
|
|
/** Substring shared by every localStorage key that persists a task ID. */
const TASK_STORAGE_KEY_MARKER = "_current_task";

/** Task statuses after which polling stops. */
const TERMINAL_TASK_STATUSES: ReadonlySet<string> = new Set(["completed", "failed", "not_found"]);

/** Delay between status polls, in milliseconds. */
const TASK_POLL_INTERVAL_MS = 1000;

/**
 * Remove every localStorage entry whose key contains `TASK_STORAGE_KEY_MARKER`.
 * No-op when `window` is undefined (server rendering).
 */
function clearPersistedTaskIds(): void {
  if (typeof window === "undefined") return;
  for (const key of Object.keys(localStorage)) {
    if (key.includes(TASK_STORAGE_KEY_MARKER)) {
      localStorage.removeItem(key);
    }
  }
}

/**
 * Provides task-tracking state to its subtree.
 *
 * While a task ID is set, the provider polls `/api/videos/tasks/{id}`
 * immediately and then every second, exposing the latest record as
 * `currentTask`. Polling stops when the task reaches a terminal status
 * ("completed", "failed", "not_found") or a request fails. In either case
 * `isGenerating` is reset and all persisted task IDs are removed from
 * localStorage.
 *
 * On mount, a task ID found under a localStorage key containing
 * "_current_task" is restored and polling resumes. This file never writes
 * those keys; presumably callers persist them when starting a task — confirm.
 */
export function TaskProvider({ children }: { children: ReactNode }) {
  const [currentTask, setCurrentTask] = useState<Task | null>(null);
  const [isGenerating, setIsGenerating] = useState(false);
  const [taskId, setTaskId] = useState<string | null>(null);

  // Poll the task status while a task ID is set.
  useEffect(() => {
    if (!taskId) return;

    const requestUrl = `/api/videos/tasks/${encodeURIComponent(taskId)}`;

    // Set by the cleanup below. A response that arrives after the task
    // changed (startTask/clearTask) or the provider unmounted must not
    // overwrite the newer state or cancel the newer task.
    let cancelled = false;
    // Prevents overlapping requests when one takes longer than the interval.
    let inFlight = false;

    // Stop tracking: reset flags and drop persisted task IDs.
    const stopTracking = () => {
      setIsGenerating(false);
      setTaskId(null);
      clearPersistedTaskIds();
    };

    const pollTask = async () => {
      if (inFlight) return;
      inFlight = true;
      try {
        const { data: res } = await api.get<ApiResponse<Task>>(requestUrl);
        if (cancelled) return;
        const task = unwrap(res);
        setCurrentTask(task);

        // Completed, failed, or unknown task: stop polling.
        if (TERMINAL_TASK_STATUSES.has(task.status)) {
          stopTracking();
        }
      } catch (error) {
        if (cancelled) return;
        // Any request/unwrap error also ends tracking (original behavior).
        console.error("轮询任务失败:", error);
        stopTracking();
      } finally {
        inFlight = false;
      }
    };

    // Poll once right away, then on a fixed interval.
    void pollTask();
    const interval = setInterval(() => void pollTask(), TASK_POLL_INTERVAL_MS);

    return () => {
      cancelled = true;
      clearInterval(interval);
    };
  }, [taskId]);

  // On mount, resume a task whose ID was persisted in localStorage.
  useEffect(() => {
    if (typeof window === "undefined") return;

    // Only the first matching key is restored.
    const taskKey = Object.keys(localStorage).find((key) => key.includes(TASK_STORAGE_KEY_MARKER));
    if (!taskKey) return;

    const savedTaskId = localStorage.getItem(taskKey);
    if (savedTaskId) {
      console.log("[TaskContext] 恢复任务:", savedTaskId);
      // eslint-disable-next-line react-hooks/set-state-in-effect
      setTaskId(savedTaskId);
      // eslint-disable-next-line react-hooks/set-state-in-effect
      setIsGenerating(true);
    }
  }, []);

  // Stable identities so consumers can safely list these in effect deps.
  const startTask = useCallback((newTaskId: string) => {
    setTaskId(newTaskId);
    setIsGenerating(true);
    setCurrentTask(null);
  }, []);

  // NOTE(review): does not touch localStorage, so a persisted ID is restored
  // on the next page load unless the caller removes it — confirm intent.
  const clearTask = useCallback(() => {
    setTaskId(null);
    setIsGenerating(false);
    setCurrentTask(null);
  }, []);

  // Memoized so consumers re-render only when the exposed state changes.
  const value = useMemo<TaskContextType>(
    () => ({ currentTask, isGenerating, startTask, clearTask }),
    [currentTask, isGenerating, startTask, clearTask],
  );

  return <TaskContext.Provider value={value}>{children}</TaskContext.Provider>;
}
|
|
|
|
/**
 * Read the task context.
 *
 * @returns The value published by the nearest `TaskProvider`.
 * @throws Error when called outside a `TaskProvider`.
 */
export function useTask(): TaskContextType {
  const taskContext = useContext(TaskContext);
  if (taskContext !== undefined) {
    return taskContext;
  }
  throw new Error("useTask must be used within a TaskProvider");
}
|