75 lines
2.4 KiB
Dart
75 lines
2.4 KiB
Dart
import 'dart:convert';
|
|
|
|
import 'package:bogui/global.dart';
|
|
import 'package:bogui/load_env.dart';
|
|
import 'package:flutter/foundation.dart';
|
|
import 'package:flutter/material.dart';
|
|
import 'package:flutter_riverpod/flutter_riverpod.dart';
|
|
import 'package:openai_gpt3_api/completion.dart';
|
|
import 'package:openai_gpt3_api/openai_gpt3_api.dart';
|
|
|
|
/// Wrapper around the `openai_gpt3_api` [GPT3] client that appends
/// completions to a Riverpod-managed prompt text field.
class OpenAI {
  /// Client created by [init]; must be initialized before [completionEasy]
  /// is called.
  late GPT3 gpt;

  /// Controller holding the prompt text; completions are appended to it.
  final prompt =
      StateProvider<TextEditingController>((ref) => TextEditingController());

  /// Whether a completion request is currently in flight.
  final isLoading = StateProvider<bool>((ref) => false);

  /// Sampling temperature passed to each completion request.
  final temperature = StateProvider<double>((ref) => 0.7);

  /// Creates [gpt] with the OpenAI secret key.
  ///
  /// In debug builds ([kDebugMode]) the key is the `OPENAPI_SECRET_KEY` entry
  /// returned by `loadEnv()`. Otherwise it is the compile-time
  /// `OPENAPI_SECRET_KEY` environment declaration (`--dart-define`). If the
  /// chosen source does not provide the key, an empty string is used.
  Future<void> init() async {
    final env = await loadEnv();
    var openaiSecretFile = '';
    if (env.containsKey('OPENAPI_SECRET_KEY')) {
      openaiSecretFile = env['OPENAPI_SECRET_KEY']!;
    }
    const openaiSecretEnv = String.fromEnvironment('OPENAPI_SECRET_KEY');
    gpt = GPT3(kDebugMode ? openaiSecretFile : openaiSecretEnv);
  }

  /// Requests a completion for the current prompt text and appends it to the
  /// prompt field, moving the cursor to the end.
  ///
  /// Returns `''` without doing anything when:
  /// - another request is already in flight ([isLoading]), or
  /// - the prompt is shorter than 2 characters.
  ///
  /// If the API call throws, the error message is appended to the prompt and
  /// also returned. [isLoading] is reset when the call finishes, including
  /// when post-processing throws.
  ///
  /// Returns the text that was appended to the prompt.
  Future<String> completionEasy(WidgetRef ref) async {
    if (ref.read(isLoading) || ref.read(prompt).text.length < 2) return '';

    ref.read(isLoading.notifier).state = true;
    // `finally` guarantees the flag is cleared; a flag stuck at `true` would
    // make every later call return '' through the guard above.
    try {
      final CompletionApiResult answer;
      try {
        answer = await gpt.completion(ref.read(prompt).text,
            maxTokens: 250,
            engine: Engine.davinci3,
            temperature: ref.read(temperature),
            echo: false,
            stream: false);
      } catch (e) {
        log.d(e);
        ref.read(prompt).text +=
            "\nJe n'ai pas la bonne clé API secret pour OpenAI, connard.\n$e";
        return "\nJe n'ai pas la bonne clé API secret pour OpenAI, connard.\n$e";
      }

      final completion =
          _repairUtf8(answer.choices.map((choice) => choice.text).join());

      final controller = ref.read(prompt);
      controller.text += completion;
      controller.selection = TextSelection.fromPosition(
          TextPosition(offset: controller.text.length));
      return completion;
    } finally {
      ref.read(isLoading.notifier).state = false;
    }
  }

  /// Reinterprets each UTF-16 code unit of [text] as a byte and decodes the
  /// result as UTF-8. If that fails, returns [text] unchanged.
  ///
  /// NOTE(review): presumably this undoes a byte-per-character (Latin-1 style)
  /// decode of the API response somewhere upstream — confirm against the
  /// `openai_gpt3_api` client.
  static String _repairUtf8(String text) {
    try {
      return utf8.decode(text.codeUnits);
    } on FormatException {
      // Not a valid UTF-8 byte sequence (e.g. the text is already properly
      // decoded), so there is nothing to repair.
      return text;
    }
  }
}
|
|
|
|
/// Truncates [val] toward zero, keeping at most [decimals] digits after the
/// decimal point.
///
/// The cut is made on the shortest decimal representation of [val] (the one
/// produced by [double.toString]) rather than with floating-point arithmetic.
/// Binary rounding in a multiply (e.g. `0.29 * 100` is slightly below `29`)
/// therefore cannot drop a digit.
///
/// Special cases:
/// - values with fewer than [decimals] fractional digits are returned
///   unchanged;
/// - NaN and infinities are returned as-is;
/// - a negative [decimals] is treated as `0`.
double truncateDouble(double val, int decimals) {
  // NaN / Infinity stringify as "NaN" / "Infinity", which have no decimal
  // digits for the slicing below to work on.
  if (!val.isFinite) return val;
  final wanted = decimals < 0 ? 0 : decimals;

  // toString() switches to exponent form for very small or very large
  // magnitudes (e.g. "1.5e-7"), so expand to positional digits first.
  final valString = _toPositionalDecimal(val);
  final dotIndex = valString.indexOf('.');
  // No fractional part at all: the value is already integral.
  if (dotIndex < 0) return val;

  // Not enough decimals: keep everything that is there.
  final totalDecimals = valString.length - dotIndex - 1;
  final kept = wanted < totalDecimals ? wanted : totalDecimals;

  // With zero kept decimals, drop the dot too (e.g. "3.99" -> "3").
  final end = kept == 0 ? dotIndex : dotIndex + kept + 1;
  return double.parse(valString.substring(0, end));
}

/// Returns [value] in positional (non-exponent) decimal notation, using the
/// same significant digits as [double.toString].
///
/// Assumes [value] is finite.
String _toPositionalDecimal(double value) {
  final text = value.toString();
  final eIndex = text.indexOf('e');
  if (eIndex < 0) return text;

  final exponent = int.parse(text.substring(eIndex + 1));
  var mantissa = text.substring(0, eIndex);
  final sign = mantissa.startsWith('-') ? '-' : '';
  if (sign.isNotEmpty) mantissa = mantissa.substring(1);

  final dot = mantissa.indexOf('.');
  final intDigits = dot < 0 ? mantissa : mantissa.substring(0, dot);
  final digits = dot < 0 ? mantissa : intDigits + mantissa.substring(dot + 1);
  // Index in `digits` where the decimal point lands after applying the
  // exponent.
  final point = intDigits.length + exponent;

  if (point <= 0) return '${sign}0.${'0' * -point}$digits';
  if (point >= digits.length) {
    return '$sign$digits${'0' * (point - digits.length)}';
  }
  return '$sign${digits.substring(0, point)}.${digits.substring(point)}';
}
|