summaryrefslogtreecommitdiff
path: root/shitoutcode/__main__.py
blob: 09ca715cb980b6b8859071bd5afc33e84f989b2f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import random
import sys
from shitoutcode import Env, LLMAPIException, SourceFile, gen_rand_srccode

# Program name used as the prefix of error messages written to stderr.
ARGV0 = "shitoutcode"

def loadEnv_from_env () -> Env:
	"""Build an Env from environment variables, falling back to defaults.

	Variables read (unset or empty means "use the default"):
	  LLM_MODEL       -> model              (str,   default "gpt-4o-mini")
	  OUT_TOKENS_MIN  -> output_tokens_min  (int,   default: Env's own value)
	  OUT_TOKENS_MAX  -> output_tokens_max  (int,   default: Env's own value)
	  LLM_TEMP        -> temp               (float, default None)
	  LLM_SEED        -> seed               (int,   default None)
	  LLM_MAX_TOKENS  -> max_tokens         (int,   default: Env's own value)

	Raises:
	  ValueError: if a set variable cannot be converted to its type; the
	    message names the offending variable.
	"""
	def from_env (name: str, t: type, dv = None):
		# Unset or empty variables fall back to dv; anything else must
		# convert cleanly with t.
		val = os.getenv(name)
		if not val:
			return dv
		try:
			return t(val)
		except ValueError as e:
			raise ValueError(f"{name}={val!r} is not a valid {t.__name__}") from e

	ret = Env()
	ret.model = from_env("LLM_MODEL", str, "gpt-4o-mini")
	ret.output_tokens_min = from_env("OUT_TOKENS_MIN", int, ret.output_tokens_min)
	ret.output_tokens_max = from_env("OUT_TOKENS_MAX", int, ret.output_tokens_max)
	ret.temp = from_env("LLM_TEMP", float)
	ret.seed = from_env("LLM_SEED", int)
	# max_tokens has its own variable; it must not be derived from LLM_SEED.
	ret.max_tokens = from_env("LLM_MAX_TOKENS", int, ret.max_tokens)

	return ret

def getPathLine (x: SourceFile) -> str:
	"""Return the display path of *x*.

	The result is "<category><os.path.sep><name>" when x.category is truthy,
	otherwise just x.name.
	"""
	if not x.category:
		return x.name
	return os.path.sep.join((x.category, x.name))

def pickRandLang () -> str:
	"""Return one of "C", "C++", "Javascript", "Java", chosen uniformly.

	Draws from the module-level ``random`` generator, so the pick is
	reproducible after ``random.seed()``.
	"""
	return random.choice(("C", "C++", "Javascript", "Java"))

# Entry point: shitoutcode [LANG [EXTRA_PROMPT]]
# LANG defaults to a random pick; EXTRA_PROMPT (optional) is passed through
# to gen_rand_srccode unchanged.
env = loadEnv_from_env()
lang = sys.argv[1] if len(sys.argv) > 1 else pickRandLang()
extra_prompt = sys.argv[2] if len(sys.argv) > 2 else None

try:
	result = gen_rand_srccode(lang, extra_prompt, env)
except LLMAPIException as e:
	# Text-mode stderr already translates "\n" to the platform newline;
	# writing os.linesep here would produce "\r\r\n" on Windows.
	sys.stderr.write(ARGV0 + ": model gave up: " + str(e) + "\n")
	# sys.exit, not the site-provided exit(), which may be absent (python -S).
	sys.exit(1)

# isinstance, not `is`: result is an instance, never the SourceFile class.
if isinstance(result, SourceFile):
	print(getPathLine(result))
	print()
	print(result.contents)
else:
	# Any other result is printed as-is.
	print(result)