123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334 |
- using System;
- using System.Collections.Generic;
- using System.Diagnostics;
- using System.IO;
- using System.Linq;
- using System.Text;
- using System.Threading;
- using System.Threading.Tasks;
- using LLM.Editor.Analysis;
- using LLM.Editor.Api;
- using LLM.Editor.Core;
- using LLM.Editor.Data;
- using LLM.Editor.Helper;
- using UnityEditor;
- using UnityEngine.Networking;
- using Debug = UnityEngine.Debug;
- using Object = UnityEngine.Object;
namespace LLM.Editor.Client
{
    /// <summary>
    /// The client responsible for communicating with the live Google Gemini API.
    /// </summary>
    /// <remarks>
    /// Requests are POSTed to the <c>aiplatform.googleapis.com</c> <c>generateContent</c> endpoint
    /// described by the project's <see cref="Settings.MCPSettings"/> asset and are authorised with a
    /// bearer token printed by <c>gcloud auth print-access-token</c>. Command JSON extracted from the
    /// model's reply is queued on <see cref="CommandExecutor"/>.
    /// </remarks>
    public class GeminiApiClient : ILlmApiClient
    {
        /// <summary>Envelope shape for a model reply that contains a list of commands.</summary>
        [Serializable]
        private class CommandResponse { public List<CommandData> commands; }

        /// <summary>Upper bound on how long the gcloud CLI may take to print an access token.</summary>
        private const int AuthTokenTimeoutMilliseconds = 30000;

        private Settings.MCPSettings _settings;
        private string _systemPrompt;
        private bool _isInitialized;

        // Text of the most recent user turn; used as the RAG memory query and handed to
        // CommandExecutor together with the parsed commands.
        private string _lastUserPrompt;

        /// <summary>
        /// Loads the settings asset and the system prompt on first use.
        /// </summary>
        /// <returns>
        /// A completed task whose result is <c>true</c> once both the settings asset and a non-empty
        /// system prompt are available, or <c>false</c> (after logging an error) otherwise.
        /// </returns>
        public Task<bool> Initialize()
        {
            if (_isInitialized) return Task.FromResult(true);
            LoadSettings();
            LoadSystemPrompt();
            if (!_settings || string.IsNullOrEmpty(_systemPrompt))
            {
                Debug.LogError("[GeminiApiClient] Initialization failed. Check if MCPSettings and MCP_SystemPrompt.txt exist.");
                return Task.FromResult(false);
            }
            _isInitialized = true;
            return Task.FromResult(true);
        }

        /// <summary>
        /// Sends a new user request, plus a summary of any staged context objects, to the model.
        /// </summary>
        /// <param name="prompt">The user's request text.</param>
        /// <param name="stagedContext">
        /// Unity objects to summarise into the prompt; when null or all entries are null no context
        /// section is added.
        /// </param>
        /// <param name="cancellationToken">Cancels memory retrieval and the in-flight web request.</param>
        public async Task SendPrompt(string prompt, List<Object> stagedContext, CancellationToken cancellationToken = default)
        {
            if (!await Initialize()) return;
            var authToken = AcquireAuthToken();
            if (authToken == null) return;

            _lastUserPrompt = prompt;
            var fullPrompt = BuildInitialPrompt(prompt, stagedContext);
            Debug.Log("[GeminiApiClient] Sending prompt: \n" + fullPrompt);
            var chatHistory = SessionManager.LoadChatHistory();
            chatHistory.Add(new ChatEntry { role = "user", content = fullPrompt });

            await SendApiRequest(chatHistory, authToken, cancellationToken);
        }

        /// <summary>
        /// Sends a follow-up user turn (e.g. detailed context the model asked for) verbatim.
        /// </summary>
        /// <param name="detailedContext">Text appended to the chat history as the next user turn.</param>
        /// <param name="cancellationToken">Cancels memory retrieval and the in-flight web request.</param>
        public async Task SendFollowUp(string detailedContext, CancellationToken cancellationToken = default)
        {
            if (!await Initialize()) return;
            var authToken = AcquireAuthToken();
            if (authToken == null) return;

            _lastUserPrompt = detailedContext;
            var chatHistory = SessionManager.LoadChatHistory();
            chatHistory.Add(new ChatEntry { role = "user", content = detailedContext });

            await SendApiRequest(chatHistory, authToken, cancellationToken);
        }

        /// <summary>
        /// Obtains an access token via <see cref="GetAuthToken"/>, logging an error when none is available.
        /// </summary>
        /// <returns>The token, or <c>null</c> when it could not be obtained.</returns>
        private string AcquireAuthToken()
        {
            var authToken = GetAuthToken();
            if (!string.IsNullOrEmpty(authToken)) return authToken;

            Debug.LogError("[GeminiApiClient] Failed to get authentication token.");
            return null;
        }

        /// <summary>
        /// Formats the user request and, when any staged object is non-null, appends the
        /// <see cref="ContextBuilder.BuildTier1Summary"/> of the staged context between delimiters.
        /// </summary>
        private static string BuildInitialPrompt(string prompt, List<Object> stagedContext)
        {
            var promptBuilder = new StringBuilder();
            promptBuilder.AppendLine("User Request:");
            promptBuilder.AppendLine(prompt);
            if (stagedContext == null || stagedContext.All(o => o == null)) return promptBuilder.ToString();

            var tier1Summary = ContextBuilder.BuildTier1Summary(stagedContext);
            promptBuilder.AppendLine("\n--- Staged Context ---");
            promptBuilder.AppendLine(tier1Summary);
            promptBuilder.AppendLine("--- End Context ---");
            return promptBuilder.ToString();
        }

        /// <summary>
        /// Builds, sends and processes one request. Stops early if the request could not be built
        /// (cancellation) or no response body came back (cancellation or a logged error).
        /// </summary>
        private async Task SendApiRequest(List<ChatEntry> chatHistory, string authToken, CancellationToken cancellationToken)
        {
            var apiRequest = await BuildApiRequest(chatHistory, cancellationToken);
            if (apiRequest == null) return;
            var responseJson = await ExecuteWebRequest(apiRequest, authToken, cancellationToken);
            if (string.IsNullOrEmpty(responseJson)) return;

            ProcessApiResponse(responseJson, chatHistory);
        }

        /// <summary>
        /// Builds the request body from the chat history and the (augmented) system prompt.
        /// </summary>
        /// <remarks>
        /// With <c>useRagMemory</c> enabled only the most recent user turn is sent, and memories retrieved
        /// for <see cref="_lastUserPrompt"/> are folded into the system prompt instead. Otherwise the whole
        /// history is sent, with <c>"assistant"</c> roles renamed to <c>"model"</c>.
        /// </remarks>
        /// <returns>The request, or <c>null</c> if cancellation was requested during memory retrieval.</returns>
        private async Task<ApiRequest> BuildApiRequest(List<ChatEntry> chatHistory, CancellationToken cancellationToken)
        {
            string systemPromptWithContext;
            List<Content> apiContents;
            if (_settings.useRagMemory)
            {
                var relevantMemories = await MemoryRetriever.GetRelevantMemories(_lastUserPrompt, cancellationToken);
                if (cancellationToken.IsCancellationRequested) return null;

                systemPromptWithContext = await BuildAugmentedSystemPrompt(relevantMemories);

                var lastUserEntry = chatHistory.LastOrDefault(e => e.role == "user");
                apiContents = new List<Content>();
                if (lastUserEntry != null)
                {
                    apiContents.Add(new Content
                    {
                        role = "user",
                        parts = new List<Part> { new() { text = lastUserEntry.content } }
                    });
                }
            }
            else
            {
                systemPromptWithContext = await BuildAugmentedSystemPrompt(null);
                apiContents = chatHistory.Select(entry => new Content
                {
                    role = entry.role == "assistant" ? "model" : entry.role,
                    parts = new List<Part> { new() { text = entry.content } }
                }).ToList();
            }
            return new ApiRequest
            {
                system_instruction = new SystemInstruction { parts = new List<Part> { new() { text = systemPromptWithContext } } },
                contents = apiContents
            };
        }

        /// <summary>
        /// POSTs <paramref name="apiRequest"/> as JSON to the configured model's <c>generateContent</c> endpoint.
        /// </summary>
        /// <returns>
        /// The response body on success; <c>null</c> when cancelled or on an HTTP/network error (both logged).
        /// </returns>
        private async Task<string> ExecuteWebRequest(ApiRequest apiRequest, string authToken, CancellationToken cancellationToken)
        {
            // The "global" location uses the bare host; regional locations use "<region>-aiplatform...".
            var region = _settings.gcpRegion == "global" ? string.Empty : $"{_settings.gcpRegion}-";
            var url = $"https://{region}aiplatform.googleapis.com/v1/projects/{_settings.gcpProjectId}/locations/{_settings.gcpRegion}/publishers/google/models/{_settings.modelName}:generateContent";

            var jsonPayload = apiRequest.ToJson();
            using var request = new UnityWebRequest(url, "POST");
            var bodyRaw = Encoding.UTF8.GetBytes(jsonPayload);
            request.uploadHandler = new UploadHandlerRaw(bodyRaw);
            request.downloadHandler = new DownloadHandlerBuffer();
            request.SetRequestHeader("Content-Type", "application/json");
            request.SetRequestHeader("Authorization", $"Bearer {authToken}");
            var operation = request.SendWebRequest();

            // Poll the operation, yielding between checks, so a cancellation can abort it mid-flight.
            while (!operation.isDone)
            {
                if (cancellationToken.IsCancellationRequested)
                {
                    request.Abort();
                    Debug.Log("[GeminiApiClient] API request was cancelled.");
                    return null;
                }
                await Task.Yield();
            }
            if (request.result == UnityWebRequest.Result.Success)
            {
                Debug.Log($"[GeminiApiClient] API Response: \n{request.downloadHandler.text}");
                return request.downloadHandler.text;
            }
            Debug.LogError($"[GeminiApiClient] API Error: {request.error}\n{request.downloadHandler.text}");
            return null;
        }

        /// <summary>
        /// Extracts command JSON from the model's reply, appends it to the chat history (which is then
        /// saved) and queues the parsed commands on <see cref="CommandExecutor"/>.
        /// </summary>
        /// <param name="responseJson">Raw body of a successful <c>generateContent</c> response.</param>
        /// <param name="chatHistory">History ending with the user turn that produced this response.</param>
        private void ProcessApiResponse(string responseJson, List<ChatEntry> chatHistory)
        {
            var rawText = ExtractResponseText(responseJson);
            if (string.IsNullOrEmpty(rawText)) return;

            var commandJson = ExtractJsonFromString(rawText);
            if (string.IsNullOrEmpty(commandJson))
            {
                Debug.LogError($"[GeminiApiClient] Could not extract valid JSON from the LLM response text: \n{rawText}");
                return;
            }

            chatHistory.Add(new ChatEntry { role = "model", content = commandJson });
            SessionManager.SaveChatHistory(chatHistory);

            var commands = ParseCommands(commandJson);
            if (commands == null) return;

            CommandExecutor.SetQueue(commands, _lastUserPrompt);
        }

        /// <summary>
        /// Returns the text of the first non-empty part of the first candidate.
        /// </summary>
        /// <returns>
        /// The text, or <c>null</c> (after logging) when the body cannot be parsed or carries no text,
        /// e.g. no candidates at all, or a candidate without content or parts.
        /// </returns>
        private static string ExtractResponseText(string responseJson)
        {
            ApiResponse apiResponse;
            try
            {
                apiResponse = responseJson.FromJson<ApiResponse>();
            }
            catch (Exception exception)
            {
                Debug.LogException(exception);
                return null;
            }

            // Every hop may be missing in a reply that produced no output, so walk it null-safely
            // instead of indexing [0] blindly.
            var text = apiResponse?.candidates?.FirstOrDefault()?.content?.parts?
                .FirstOrDefault(part => !string.IsNullOrEmpty(part?.text))?.text;

            if (string.IsNullOrEmpty(text))
            {
                Debug.LogError($"[GeminiApiClient] API response contained no text content: \n{responseJson}");
            }
            return text;
        }

        /// <summary>
        /// Parses <paramref name="commandJson"/> as a <c>{ "commands": [...] }</c> envelope or, when that
        /// yields no command list, as a single command object.
        /// </summary>
        /// <returns>The commands to queue, or <c>null</c> (after logging) if neither shape could be parsed.</returns>
        private static List<CommandData> ParseCommands(string commandJson)
        {
            try
            {
                var commandResponse = commandJson.FromJson<CommandResponse>();
                if (commandResponse is { commands: not null }) return commandResponse.commands;

                var singleCommand = commandJson.FromJson<CommandData>();
                if (singleCommand != null) return new List<CommandData> { singleCommand };
            }
            catch (Exception exception)
            {
                Debug.LogException(exception);
                return null;
            }

            Debug.LogError($"[GeminiApiClient] Failed to parse command structure from LLM response text: {commandJson}");
            return null;
        }

        /// <summary>
        /// Caches the first <c>MCPSettings</c> asset found in the project, if it is not already loaded.
        /// </summary>
        private void LoadSettings()
        {
            if (_settings) return;
            var guids = AssetDatabase.FindAssets("t:MCPSettings");
            if (guids.Length == 0) return;
            var path = AssetDatabase.GUIDToAssetPath(guids[0]);
            _settings = AssetDatabase.LoadAssetAtPath<Settings.MCPSettings>(path);
        }

        /// <summary>
        /// Reads and caches the text of the first asset returned by the <c>"MCP_SystemPrompt"</c> search,
        /// if no prompt is cached yet.
        /// </summary>
        private void LoadSystemPrompt()
        {
            if (!string.IsNullOrEmpty(_systemPrompt)) return;
            var guids = AssetDatabase.FindAssets("MCP_SystemPrompt");
            if (guids.Length == 0) return;
            var path = AssetDatabase.GUIDToAssetPath(guids[0]);
            _systemPrompt = File.ReadAllText(path);
        }

        /// <summary>
        /// Returns the substring from the first <c>{</c> through the last <c>}</c> of <paramref name="text"/>.
        /// This strips any prose or markdown fences around a JSON object. Returns <c>null</c> when no such
        /// span exists.
        /// </summary>
        private static string ExtractJsonFromString(string text)
        {
            if (string.IsNullOrWhiteSpace(text)) return null;
            var firstBrace = text.IndexOf('{');
            var lastBrace = text.LastIndexOf('}');
            if (firstBrace == -1 || lastBrace == -1 || lastBrace < firstBrace)
            {
                return null;
            }
            return text.Substring(firstBrace, lastBrace - firstBrace + 1);
        }

        /// <summary>
        /// Runs <c>gcloud auth print-access-token</c> with the executable configured in the settings asset.
        /// </summary>
        /// <remarks>
        /// Blocks the calling thread until gcloud exits or <see cref="AuthTokenTimeoutMilliseconds"/> elapses.
        /// </remarks>
        /// <returns>
        /// The trimmed token printed on stdout. Returns <c>null</c> when settings or the gcloud path are
        /// missing, gcloud cannot be started, times out, or exits with a non-zero code. Every case except
        /// missing settings is logged.
        /// </returns>
        public string GetAuthToken()
        {
            if (!_settings) return null;
            if (string.IsNullOrWhiteSpace(_settings.gcloudPath))
            {
                Debug.LogError("[GeminiApiClient] The gcloud path is not set in MCPSettings.");
                return null;
            }

            try
            {
                var startInfo = new ProcessStartInfo
                {
                    FileName = _settings.gcloudPath,
                    Arguments = "auth print-access-token",
                    RedirectStandardOutput = true,
                    RedirectStandardError = true,
                    UseShellExecute = false,
                    CreateNoWindow = true
                };
                using var process = Process.Start(startInfo);
                if (process == null) return null;

                // Drain both pipes concurrently. Reading stdout to the end before touching stderr can
                // deadlock if gcloud fills the stderr pipe buffer first. Plain blocking reads on pool
                // threads are used so the waits below never depend on the editor's synchronization context.
                var outputTask = Task.Run(() => process.StandardOutput.ReadToEnd());
                var errorTask = Task.Run(() => process.StandardError.ReadToEnd());

                if (!process.WaitForExit(AuthTokenTimeoutMilliseconds))
                {
                    TryKill(process);
                    Debug.LogError($"[GeminiApiClient] gcloud did not exit within {AuthTokenTimeoutMilliseconds} ms while fetching an access token.");
                    return null;
                }

                // gcloud has exited, so its pipes close and both readers complete. GetResult() unwraps
                // any reader exception instead of surfacing an AggregateException.
                var accessToken = outputTask.GetAwaiter().GetResult().Trim();
                var error = errorTask.GetAwaiter().GetResult();
                if (process.ExitCode == 0) return accessToken;

                Debug.LogError($"[GeminiApiClient] gcloud auth error: {error}");
                return null;
            }
            catch (Exception e)
            {
                Debug.LogError($"[GeminiApiClient] Exception while getting auth token: {e.Message}");
                return null;
            }
        }

        /// <summary>Best-effort termination of a child process that overran its time budget.</summary>
        private static void TryKill(Process process)
        {
            try
            {
                process.Kill();
            }
            catch (InvalidOperationException)
            {
                // The process already exited after the timeout was detected; nothing left to stop.
            }
        }

        /// <summary>
        /// Builds the system prompt sent with each request. It is the base prompt, optionally followed by
        /// past interactions as worked examples, then the session's working context as indented JSON.
        /// </summary>
        /// <param name="memories">Past interactions to include; <c>null</c> or empty adds no examples section.</param>
        /// <returns>An already-completed task carrying the assembled prompt.</returns>
        private Task<string> BuildAugmentedSystemPrompt(List<InteractionRecord> memories)
        {
            var systemPromptBuilder = new StringBuilder(_systemPrompt);
            if (memories != null && memories.Any())
            {
                systemPromptBuilder.AppendLine("\n\n## RELEVANT EXAMPLES FROM YOUR MEMORY");
                systemPromptBuilder.AppendLine("Here are some past attempts that might be relevant. Learn from them.");

                foreach (var memory in memories)
                {
                    systemPromptBuilder.AppendLine("\n---");
                    systemPromptBuilder.AppendLine($"PAST TASK: User asked '{memory.UserPrompt}'");
                    systemPromptBuilder.AppendLine($"YOUR ACTION: You used the command '{memory.LLMResponse.commandName}' with parameters: {memory.LLMResponse.jsonData}");
                    systemPromptBuilder.AppendLine($"OUTCOME: {memory.Outcome}");
                    // Feedback is only included for failed attempts.
                    if (memory.Outcome == Commands.CommandOutcome.Error)
                    {
                        systemPromptBuilder.AppendLine($"FEEDBACK: {memory.Feedback}");
                    }
                    systemPromptBuilder.AppendLine("---");
                }
            }
            var workingContext = SessionManager.LoadWorkingContext();
            systemPromptBuilder.AppendLine("\n\n## CURRENT WORKING CONTEXT");
            systemPromptBuilder.AppendLine("This is your short-term memory. Use the data here before asking for it again.");
            systemPromptBuilder.AppendLine("```json");
            systemPromptBuilder.AppendLine(workingContext.ToString(Newtonsoft.Json.Formatting.Indented));
            systemPromptBuilder.AppendLine("```");
            return Task.FromResult(systemPromptBuilder.ToString());
        }
    }
}
|